diff --git a/driver-core/src/main/com/mongodb/connection/AsyncCompletionHandler.java b/driver-core/src/main/com/mongodb/connection/AsyncCompletionHandler.java index 893c5f0eedf..a286f346427 100644 --- a/driver-core/src/main/com/mongodb/connection/AsyncCompletionHandler.java +++ b/driver-core/src/main/com/mongodb/connection/AsyncCompletionHandler.java @@ -16,6 +16,7 @@ package com.mongodb.connection; +import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; /** @@ -38,4 +39,17 @@ public interface AsyncCompletionHandler { * @param t the exception that describes the failure */ void failed(Throwable t); + + /** + * @return this handler as a callback + */ + default SingleResultCallback asCallback() { + return (r, t) -> { + if (t != null) { + failed(t); + } else { + completed(r); + } + }; + } } diff --git a/driver-core/src/main/com/mongodb/connection/Stream.java b/driver-core/src/main/com/mongodb/connection/Stream.java index 9c8a3a03d20..1edf41ab742 100644 --- a/driver-core/src/main/com/mongodb/connection/Stream.java +++ b/driver-core/src/main/com/mongodb/connection/Stream.java @@ -17,6 +17,7 @@ package com.mongodb.connection; import com.mongodb.ServerAddress; +import com.mongodb.internal.async.SingleResultCallback; import org.bson.ByteBuf; import java.io.IOException; @@ -43,7 +44,7 @@ public interface Stream extends BufferProvider{ * * @param handler the completion handler for opening the stream */ - void openAsync(AsyncCompletionHandler handler); + void openAsync(SingleResultCallback handler); /** * Write each buffer in the list to the stream in order, blocking until all are completely written. diff --git a/driver-core/src/main/com/mongodb/connection/TlsChannelStreamFactoryFactory.java b/driver-core/src/main/com/mongodb/connection/TlsChannelStreamFactoryFactory.java index 90bc987272f..9fb5e59ee3a 100644 --- a/driver-core/src/main/com/mongodb/connection/TlsChannelStreamFactoryFactory.java +++ b/driver-core/src/main/com/mongodb/connection/TlsChannelStreamFactoryFactory.java @@ -19,6 +19,7 @@ import com.mongodb.MongoClientException; import com.mongodb.MongoSocketOpenException; import com.mongodb.ServerAddress; +import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.connection.AsynchronousChannelStream; import com.mongodb.internal.connection.ExtendedAsynchronousByteChannel; import com.mongodb.internal.connection.PowerOfTwoBufferPool; @@ -201,7 +202,8 @@ public boolean supportsAdditionalTimeout() { @SuppressWarnings("deprecation") @Override - public void openAsync(final AsyncCompletionHandler handler) { + public void openAsync(final SingleResultCallback callback) { + AsyncCompletionHandler handler = callback.asHandler(); isTrue("unopened", getChannel() == null); try { SocketChannel socketChannel = SocketChannel.open(); diff --git a/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java b/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java index bb971603ab5..37b4bbc2fb3 100644 --- a/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java +++ b/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java @@ -28,6 +28,7 @@ import com.mongodb.connection.SocketSettings; import com.mongodb.connection.SslSettings; import com.mongodb.connection.Stream; +import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.connection.netty.NettyByteBuf; import com.mongodb.lang.Nullable; import io.netty.bootstrap.Bootstrap; @@ -158,13 +159,14 @@ public ByteBuf getBuffer(final int size) { @Override public void open() throws IOException { FutureAsyncCompletionHandler handler = new FutureAsyncCompletionHandler<>(); - openAsync(handler); + openAsync(handler.asCallback()); handler.get(); } @SuppressWarnings("deprecation") @Override - public void openAsync(final AsyncCompletionHandler handler) { + public void openAsync(final SingleResultCallback callback) { + AsyncCompletionHandler handler = callback.asHandler(); Queue socketAddressQueue; try { diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java b/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java new file mode 100644 index 00000000000..b385670ae88 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java @@ -0,0 +1,26 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +/** + * See tests for usage (AsyncFunctionsTest). + *

+ * This class is not part of the public API and may be removed or changed at any time + */ +@FunctionalInterface +public interface AsyncConsumer extends AsyncFunction { +} diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java new file mode 100644 index 00000000000..76dbccc5081 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java @@ -0,0 +1,31 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +/** + * See tests for usage (AsyncFunctionsTest). + *

+ * This class is not part of the public API and may be removed or changed at any time + */ +@FunctionalInterface +public interface AsyncFunction { + /** + * This should not be called externally, but should be implemented as a + * lambda. To "finish" an async chain, use one of the "finish" methods. + */ + void unsafeFinish(T value, SingleResultCallback callback); +} diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java new file mode 100644 index 00000000000..b9089252f49 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -0,0 +1,158 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import com.mongodb.internal.async.function.RetryState; +import com.mongodb.internal.async.function.RetryingAsyncCallbackSupplier; + +import java.util.function.Predicate; +import java.util.function.Supplier; + +/** + * See tests for usage (AsyncFunctionsTest). + *

+ * This class is not part of the public API and may be removed or changed at any time + */ +@FunctionalInterface +public interface AsyncRunnable extends AsyncSupplier, AsyncConsumer { + + static AsyncRunnable beginAsync() { + return (c) -> c.onResult(null, null); + } + + /** + * Must be invoked at end of async chain + * @param runnable the sync code to invoke (under non-exceptional flow) + * prior to the callback + * @param callback the callback provided by the method the chain is used in + */ + default void thenRunAndFinish(final Runnable runnable, final SingleResultCallback callback) { + this.finish((r, e) -> { + if (e != null) { + callback.onResult(null, e); + return; + } + try { + runnable.run(); + } catch (Throwable t) { + callback.onResult(null, t); + return; + } + callback.onResult(null, null); + }); + } + + /** + * See {@link #thenRunAndFinish(Runnable, SingleResultCallback)}, but the runnable + * will always be executed, including on the exceptional path. + * @param runnable the runnable + * @param callback the callback + */ + default void thenAlwaysRunAndFinish(final Runnable runnable, final SingleResultCallback callback) { + this.finish((r, e) -> { + try { + runnable.run(); + } catch (Throwable t) { + if (e != null) { + t.addSuppressed(e); + } + callback.onResult(null, t); + return; + } + callback.onResult(null, e); + }); + } + + /** + * @param runnable The async runnable to run after this runnable + * @return the composition of this runnable and the runnable, a runnable + */ + default AsyncRunnable thenRun(final AsyncRunnable runnable) { + return (c) -> { + this.unsafeFinish((r, e) -> { + if (e == null) { + runnable.unsafeFinish(c); + } else { + c.onResult(null, e); + } + }); + }; + } + + /** + * @param condition the condition to check + * @param runnable The async runnable to run after this runnable, + * if and only if the condition is met + * @return the composition of this runnable and the runnable, a runnable + */ + default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRunnable runnable) { + return (callback) -> { + this.unsafeFinish((r, e) -> { + if (e != null) { + callback.onResult(null, e); + return; + } + boolean matched; + try { + matched = condition.get(); + } catch (Throwable t) { + callback.onResult(null, t); + return; + } + if (matched) { + runnable.unsafeFinish(callback); + } else { + callback.onResult(null, null); + } + }); + }; + } + + /** + * @param supplier The supplier to supply using after this runnable + * @return the composition of this runnable and the supplier, a supplier + * @param The return type of the resulting supplier + */ + default AsyncSupplier thenSupply(final AsyncSupplier supplier) { + return (c) -> { + this.unsafeFinish((r, e) -> { + if (e == null) { + supplier.unsafeFinish(c); + } else { + c.onResult(null, e); + } + }); + }; + } + + /** + * @param runnable the runnable to loop + * @param shouldRetry condition under which to retry + * @return the composition of this, and the looping branch + * @see RetryingAsyncCallbackSupplier + */ + default AsyncRunnable thenRunRetryingWhile( + final AsyncRunnable runnable, final Predicate shouldRetry) { + return thenRun(callback -> { + new RetryingAsyncCallbackSupplier( + new RetryState(), + (rs, lastAttemptFailure) -> shouldRetry.test(lastAttemptFailure), + cb -> runnable.finish(cb) // finish is required here, to handle exceptions + ).get(callback); + }); + } +} diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java new file mode 100644 index 00000000000..ede848eb344 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -0,0 +1,137 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import java.util.function.Predicate; + + +/** + * See tests for usage (AsyncFunctionsTest). + *

+ * This class is not part of the public API and may be removed or changed at any time + */ +@FunctionalInterface +public interface AsyncSupplier extends AsyncFunction { + /** + * This should not be called externally to this API. It should be + * implemented as a lambda. To "finish" an async chain, use one of + * the "finish" methods. + * + * @see #finish(SingleResultCallback) + */ + void unsafeFinish(SingleResultCallback callback); + + /** + * This is the async variant of a supplier's get method. + * This method must only be used when this AsyncSupplier corresponds + * to a {@link java.util.function.Supplier} (and is therefore being + * used within an async chain method lambda). + * @param callback the callback + */ + default void getAsync(final SingleResultCallback callback) { + unsafeFinish(callback); + } + + @Override + default void unsafeFinish(final Void value, final SingleResultCallback callback) { + unsafeFinish(callback); + } + + /** + * Must be invoked at end of async chain. + * @param callback the callback provided by the method the chain is used in + */ + default void finish(final SingleResultCallback callback) { + final boolean[] callbackInvoked = {false}; + try { + this.unsafeFinish((v, e) -> { + callbackInvoked[0] = true; + callback.onResult(v, e); + }); + } catch (Throwable t) { + if (callbackInvoked[0]) { + throw t; + } else { + callback.onResult(null, t); + } + } + } + + /** + * @param function The async function to run after this supplier + * @return the composition of this supplier and the function, a supplier + * @param The return type of the resulting supplier + */ + default AsyncSupplier thenApply(final AsyncFunction function) { + return (c) -> { + this.unsafeFinish((v, e) -> { + if (e == null) { + function.unsafeFinish(v, c); + } else { + c.onResult(null, e); + } + }); + }; + } + + + /** + * @param consumer The async consumer to run after this supplier + * @return the composition of this supplier and the consumer, a runnable + */ + default AsyncRunnable thenConsume(final AsyncConsumer consumer) { + return (c) -> { + this.unsafeFinish((v, e) -> { + if (e == null) { + consumer.unsafeFinish(v, c); + } else { + c.onResult(null, e); + } + }); + }; + } + + /** + * @param errorCheck A check, comparable to a catch-if/otherwise-rethrow + * @param errorFunction The branch to execute if the error matches + * @return The composition of this, and the conditional branch + */ + default AsyncSupplier onErrorIf( + final Predicate errorCheck, + final AsyncFunction errorFunction) { + return (callback) -> this.finish((r, e) -> { + if (e == null) { + callback.onResult(r, null); + return; + } + boolean errorMatched; + try { + errorMatched = errorCheck.test(e); + } catch (Throwable t) { + t.addSuppressed(e); + callback.onResult(null, t); + return; + } + if (errorMatched) { + errorFunction.unsafeFinish(e, callback); + } else { + callback.onResult(null, e); + } + }); + } + +} diff --git a/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java b/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java index 573c1ba423c..224dae62179 100644 --- a/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java +++ b/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java @@ -16,6 +16,8 @@ package com.mongodb.internal.async; +import com.mongodb.assertions.Assertions; +import com.mongodb.connection.AsyncCompletionHandler; import com.mongodb.internal.async.function.AsyncCallbackFunction; import com.mongodb.lang.Nullable; @@ -34,4 +36,32 @@ public interface SingleResultCallback { * @throws Error Never, on the best effort basis. */ void onResult(@Nullable T result, @Nullable Throwable t); + + /** + * @return this callback as a handler + */ + default AsyncCompletionHandler asHandler() { + return new AsyncCompletionHandler() { + @Override + public void completed(@Nullable final T result) { + onResult(result, null); + } + @Override + public void failed(final Throwable t) { + onResult(null, t); + } + }; + } + + default void complete(final SingleResultCallback callback) { + // takes a void callback (itself) to help ensure that this method + // is not accidentally used when "complete(T)" should have been used + // instead, since results are not marked nullable. + Assertions.assertTrue(callback == this); + this.onResult(null, null); + } + + default void complete(final T result) { + this.onResult(result, null); + } } diff --git a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java index 7304a9ef9b5..02fdbdf9699 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java @@ -32,17 +32,4 @@ public interface AsyncCallbackRunnable { */ void run(SingleResultCallback callback); - /** - * Converts this {@link AsyncCallbackSupplier} to {@link AsyncCallbackSupplier}{@code }. - */ - default AsyncCallbackSupplier asSupplier() { - return this::run; - } - - /** - * @see AsyncCallbackSupplier#whenComplete(Runnable) - */ - default AsyncCallbackRunnable whenComplete(final Runnable after) { - return callback -> asSupplier().whenComplete(after).get(callback); - } } diff --git a/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java b/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java index 9ebe02f5aa7..92233a072be 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java @@ -84,6 +84,13 @@ public RetryingAsyncCallbackSupplier( this.asyncFunction = asyncFunction; } + public RetryingAsyncCallbackSupplier( + final RetryState state, + final BiPredicate retryPredicate, + final AsyncCallbackSupplier asyncFunction) { + this(state, (previouslyChosenFailure, lastAttemptFailure) -> lastAttemptFailure, retryPredicate, asyncFunction); + } + @Override public void get(final SingleResultCallback callback) { /* `asyncFunction` and `callback` are the only externally provided pieces of code for which we do not need to care about diff --git a/driver-core/src/main/com/mongodb/internal/connection/AsynchronousChannelStream.java b/driver-core/src/main/com/mongodb/internal/connection/AsynchronousChannelStream.java index bb0d5953bfb..8999bdfbe43 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AsynchronousChannelStream.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AsynchronousChannelStream.java @@ -127,7 +127,7 @@ private void readAsync(final int numBytes, final int additionalTimeout, final As @Override public void open() throws IOException { FutureAsyncCompletionHandler handler = new FutureAsyncCompletionHandler<>(); - openAsync(handler); + openAsync(handler.asCallback()); handler.getOpen(); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/AsynchronousSocketChannelStream.java b/driver-core/src/main/com/mongodb/internal/connection/AsynchronousSocketChannelStream.java index 6a956247ed3..f5cd746718c 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AsynchronousSocketChannelStream.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AsynchronousSocketChannelStream.java @@ -21,6 +21,7 @@ import com.mongodb.ServerAddress; import com.mongodb.connection.AsyncCompletionHandler; import com.mongodb.connection.SocketSettings; +import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; import java.io.IOException; @@ -56,7 +57,9 @@ public AsynchronousSocketChannelStream(final ServerAddress serverAddress, final @SuppressWarnings("deprecation") @Override - public void openAsync(final AsyncCompletionHandler handler) { + public void openAsync(final SingleResultCallback callback) { + AsyncCompletionHandler handler = callback.asHandler(); + isTrue("unopened", getChannel() == null); Queue socketAddressQueue; diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index cfeeece6126..979160022e0 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -72,6 +72,7 @@ import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.isTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback; import static com.mongodb.internal.connection.CommandHelper.HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO; @@ -205,15 +206,16 @@ public int getGeneration() { @Override public void open() { - isTrue("Open already called", stream == null); - stream = streamFactory.create(getServerAddressWithResolver()); try { + isTrue("Open already called", stream == null); + stream = streamFactory.create(getServerAddressWithResolver()); stream.open(); InternalConnectionInitializationDescription initializationDescription = connectionInitializer.startHandshake(this); - initAfterHandshakeStart(initializationDescription); + initAfterHandshakeStart(initializationDescription); initializationDescription = connectionInitializer.finishHandshake(this, initializationDescription); + initAfterHandshakeFinish(initializationDescription); } catch (Throwable t) { close(); @@ -227,45 +229,25 @@ public void open() { @Override public void openAsync(final SingleResultCallback callback) { - isTrue("Open already called", stream == null, callback); - try { + beginAsync().thenRun(c -> { + isTrue("Open already called", stream == null, callback); stream = streamFactory.create(getServerAddressWithResolver()); - stream.openAsync(new AsyncCompletionHandler() { - @Override - public void completed(@Nullable final Void aVoid) { - connectionInitializer.startHandshakeAsync(InternalStreamConnection.this, - (initialResult, initialException) -> { - if (initialException != null) { - close(); - callback.onResult(null, initialException); - } else { - assertNotNull(initialResult); - initAfterHandshakeStart(initialResult); - connectionInitializer.finishHandshakeAsync(InternalStreamConnection.this, - initialResult, (completedResult, completedException) -> { - if (completedException != null) { - close(); - callback.onResult(null, completedException); - } else { - assertNotNull(completedResult); - initAfterHandshakeFinish(completedResult); - callback.onResult(null, null); - } - }); - } - }); - } - - @Override - public void failed(final Throwable t) { - close(); - callback.onResult(null, t); - } - }); - } catch (Throwable t) { + stream.openAsync(c); + }).thenSupply(c -> { + connectionInitializer.startHandshakeAsync(this, c); + }).thenApply((initializationDescription, c) -> { + initAfterHandshakeStart(initializationDescription); + connectionInitializer.finishHandshakeAsync(this, initializationDescription, c); + }).thenConsume((initializationDescription, c) -> { + initAfterHandshakeFinish(initializationDescription); + }).onErrorIf(t -> true, (t, c) -> { close(); - callback.onResult(null, t); - } + if (t instanceof MongoException) { + throw (MongoException) t; + } else { + throw new MongoException(t.toString(), t); + } + }).finish(callback); } private ServerAddress getServerAddressWithResolver() { @@ -336,7 +318,7 @@ private Compressor createCompressor(final MongoCompressor mongoCompressor) { public void close() { // All but the first call is a no-op if (!isClosed.getAndSet(true) && (stream != null)) { - stream.close(); + stream.close(); } } @@ -352,8 +334,9 @@ public boolean isClosed() { @Nullable @Override - public T sendAndReceive(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, - final RequestContext requestContext, final OperationContext operationContext) { + public T sendAndReceive(final CommandMessage message, + final Decoder decoder, final SessionContext sessionContext, + final RequestContext requestContext, final OperationContext operationContext) { CommandEventSender commandEventSender; try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this)) { @@ -476,8 +459,10 @@ private T receiveCommandMessageResponse(final Decoder decoder, } @Override - public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, - final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { + public void sendAndReceiveAsync(final CommandMessage message, + final Decoder decoder, final SessionContext sessionContext, + final RequestContext requestContext, final OperationContext operationContext, + final SingleResultCallback callback) { notNull("stream is open", stream, callback); if (isClosed()) { @@ -575,11 +560,9 @@ private T getCommandResult(final Decoder decoder, final ResponseBuffers r @Override public void sendMessage(final List byteBuffers, final int lastRequestId) { notNull("stream is open", stream); - if (isClosed()) { throw new MongoSocketClosedException("Cannot write to a closed stream", getServerAddress()); } - try { stream.write(byteBuffers); } catch (Exception e) { @@ -609,14 +592,14 @@ private ResponseBuffers receiveMessageWithAdditionalTimeout(final int additional @Override public void sendMessageAsync(final List byteBuffers, final int lastRequestId, final SingleResultCallback callback) { - notNull("stream is open", stream, callback); - - if (isClosed()) { - callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); - return; - } + beginAsync().thenRun(c -> { + notNull("stream is open", stream, callback); + if (isClosed()) { + throw new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()); + } - writeAsync(byteBuffers, errorHandlingCallback(callback, LOGGER)); + writeAsync(byteBuffers, errorHandlingCallback(callback, LOGGER)); + }).finish(callback); } private void writeAsync(final List byteBuffers, final SingleResultCallback callback) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index ffd0b912233..03bd4d29f2b 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -37,6 +37,7 @@ import java.util.List; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.connection.CommandHelper.HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO; import static com.mongodb.internal.connection.CommandHelper.executeCommand; @@ -73,9 +74,18 @@ public InternalStreamConnectionInitializer(final ClusterConnectionMode clusterCo @Override public InternalConnectionInitializationDescription startHandshake(final InternalConnection internalConnection) { - notNull("internalConnection", internalConnection); + long startTime = System.nanoTime(); - return initializeConnectionDescription(internalConnection); + notNull("internalConnection", internalConnection); + BsonDocument helloCommandDocument = createHelloCommand(authenticator, internalConnection); + BsonDocument helloResult; + try { + helloResult = executeCommand("admin", helloCommandDocument, clusterConnectionMode, serverApi, internalConnection); + } catch (MongoException e) { + throw mapHelloException(e); + } + setSpeculativeAuthenticateResponse(helloResult); + return createInitializationDescription(helloResult, internalConnection, startTime); } public InternalConnectionInitializationDescription finishHandshake(final InternalConnection internalConnection, @@ -91,15 +101,19 @@ public InternalConnectionInitializationDescription finishHandshake(final Interna public void startHandshakeAsync(final InternalConnection internalConnection, final SingleResultCallback callback) { long startTime = System.nanoTime(); - executeCommandAsync("admin", createHelloCommand(authenticator, internalConnection), clusterConnectionMode, serverApi, - internalConnection, (helloResult, t) -> { - if (t != null) { - callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t); - } else { - setSpeculativeAuthenticateResponse(helloResult); - callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null); - } - }); + beginAsync().thenSupply(c -> { + notNull("internalConnection", internalConnection); + BsonDocument helloCommandDocument = createHelloCommand(authenticator, internalConnection); + + beginAsync().thenSupply(c2 -> { + executeCommandAsync("admin", helloCommandDocument, clusterConnectionMode, serverApi, internalConnection, c2); + }).onErrorIf(e -> e instanceof MongoException, (t, c2) -> { + throw mapHelloException((MongoException) t); + }).thenApply((helloResult, c2) -> { + setSpeculativeAuthenticateResponse(helloResult); + c2.complete(createInitializationDescription(helloResult, internalConnection, startTime)); + }).finish(c); + }).finish(callback); } @Override @@ -121,20 +135,6 @@ public void finishHandshakeAsync(final InternalConnection internalConnection, } } - private InternalConnectionInitializationDescription initializeConnectionDescription(final InternalConnection internalConnection) { - BsonDocument helloResult; - BsonDocument helloCommandDocument = createHelloCommand(authenticator, internalConnection); - - long start = System.nanoTime(); - try { - helloResult = executeCommand("admin", helloCommandDocument, clusterConnectionMode, serverApi, internalConnection); - } catch (MongoException e) { - throw mapHelloException(e); - } - setSpeculativeAuthenticateResponse(helloResult); - return createInitializationDescription(helloResult, internalConnection, start); - } - private MongoException mapHelloException(final MongoException e) { if (checkSaslSupportedMechs && e.getCode() == USER_NOT_FOUND_CODE) { MongoCredential credential = authenticator.getMongoCredential(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/SocketStream.java b/driver-core/src/main/com/mongodb/internal/connection/SocketStream.java index 03580cc7c89..57a28772ca9 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SocketStream.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SocketStream.java @@ -26,6 +26,7 @@ import com.mongodb.connection.SocketSettings; import com.mongodb.connection.SslSettings; import com.mongodb.connection.Stream; +import com.mongodb.internal.async.SingleResultCallback; import org.bson.ByteBuf; import javax.net.SocketFactory; @@ -207,7 +208,7 @@ public ByteBuf read(final int numBytes, final int additionalTimeout) throws IOEx } @Override - public void openAsync(final AsyncCompletionHandler handler) { + public void openAsync(final SingleResultCallback callback) { throw new UnsupportedOperationException(getClass() + " does not support asynchronous operations."); } diff --git a/driver-core/src/test/functional/com/mongodb/client/TestListener.java b/driver-core/src/test/functional/com/mongodb/client/TestListener.java new file mode 100644 index 00000000000..db68065432c --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/client/TestListener.java @@ -0,0 +1,43 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.annotations.ThreadSafe; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A simple listener that consumes string events, which can be checked in tests. + */ +@ThreadSafe +public final class TestListener { + private final List events = Collections.synchronizedList(new ArrayList<>()); + + public void add(final String s) { + events.add(s); + } + + public List getEventStrings() { + return new ArrayList<>(events); + } + + public void clear() { + events.clear(); + } +} diff --git a/driver-core/src/test/functional/com/mongodb/connection/netty/NettyStreamSpecification.groovy b/driver-core/src/test/functional/com/mongodb/connection/netty/NettyStreamSpecification.groovy index 6628dfb5625..bc4dbfb6c51 100644 --- a/driver-core/src/test/functional/com/mongodb/connection/netty/NettyStreamSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/connection/netty/NettyStreamSpecification.groovy @@ -76,7 +76,7 @@ class NettyStreamSpecification extends Specification { def stream = new NettyStreamFactory(SocketSettings.builder().connectTimeout(1000, TimeUnit.MILLISECONDS).build(), SslSettings.builder().build()).create(serverAddress) - def callback = new CallbackErrorHolder() + def callback = new CallbackErrorHolder().asCallback() when: stream.openAsync(callback) diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/AsyncSocketChannelStreamSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/connection/AsyncSocketChannelStreamSpecification.groovy index add5413f911..614ec42373a 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/AsyncSocketChannelStreamSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/AsyncSocketChannelStreamSpecification.groovy @@ -83,7 +83,7 @@ class AsyncSocketChannelStreamSpecification extends Specification { def stream = new AsynchronousSocketChannelStream(serverAddress, SocketSettings.builder().connectTimeout(100, MILLISECONDS).build(), new PowerOfTwoBufferPool(), null) - def callback = new CallbackErrorHolder() + def callback = new CallbackErrorHolder().asCallback() when: stream.openAsync(callback) diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java new file mode 100644 index 00000000000..8c2c4d20b35 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java @@ -0,0 +1,1131 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.internal.async; + +import com.mongodb.client.TestListener; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +final class AsyncFunctionsTest { + private final TestListener listener = new TestListener(); + private final InvocationTracker invocationTracker = new InvocationTracker(); + private boolean throwExceptionsFromAsync = false; + + @Test + void testBasicVariations1() { + /* + Some of our methods have "Async" counterparts. These "Async" methods + must implement the same behaviour asynchronously. In these "Async" + methods, a SingleResultCallback is provided as a parameter, and the + method calls at least one other "Async" method (or it invokes a + non-driver async API). + + The API tested here facilitates the writing of such methods using + standardized, tested, and non-nested boilerplate. For example, given + the following "sync" method: + + public T myMethod() + sync(1); + } + + The async counterpart would be: + + public void myMethodAsync(SingleResultCallback callback) + beginAsync().thenRun(c -> { // 1, 2 + async(1, c); // 3, 4 + }).finish(callback); // 5 + } + + Usage: + 1. Start an async chain using the "beginAsync" static method. + 2. Use an appropriate chaining method (then...), which will provide "c" + 3. copy all sync code into that method; convert sync methods to async + 4. at any async method, pass in "c", and end that "block" + 5. provide the original "callback" at the end of the chain via "finish" + + (The above example is tested at the end of this method, and other tests + will provide additional examples.) + + Requirements and conventions: + + Each async lambda MUST invoke its async method with "c", and MUST return + immediately after invoking that method. It MUST NOT, for example, have + a catch or finally (including close on try-with-resources) after the + invocation of the async method. + + In cases where the async method has "mixed" returns (some of which are + plain sync, some async), the "c" callback MUST be completed on the + plain sync path, `c.complete()`, followed by a return or end of method. + + Chains starting with "beginAsync" correspond roughly to code blocks. + This includes the method body, blocks used in if/try/catch/while/etc. + statements, and places where anonymous code blocks might be used. For + clarity, such nested/indented chains might be omitted (where possible, + as demonstrated in the tests/examples below). + + Plain sync code MAY throw exceptions, and SHOULD NOT attempt to handle + them asynchronously. The exceptions will be caught and handled by the + code blocks that contain this sync code. + + All code, including "plain" code (parameter checks) SHOULD be placed + within the "boilerplate". This ensures that exceptions are handled, + and facilitates comparison/review. This excludes code that must be + "shared", such as lambda and variable declarations. + + A curly-braced lambda body (with no linebreak before "."), as shown + below, SHOULD be used (for consistency, and ease of comparison/review). + */ + + // the number of expected variations is often: 1 + N methods invoked + // 1 variation with no exceptions, and N per an exception in each method + assertBehavesSameVariations(2, + () -> { + // single sync method invocations... + sync(1); + }, + (callback) -> { + // ...become a single async invocation, wrapped in begin-thenRun/finish: + beginAsync().thenRun(c -> { + async(1, c); + }).finish(callback); + }); + /* + Code review checklist for async code: + + 1. Is everything inside the boilerplate? + 2. Is "callback" supplied to "finish"? + 3. In each block and nested block, is that same block's "c" always passed/completed at the end of execution? + 4. Is every c.complete followed by a return, to end execution? + 5. Have all sync method calls been converted to async, where needed? + */ + } + + @Test + void testBasicVariations2() { + // tests pairs + // converting: plain-sync, sync-plain, sync-sync + // (plain-plain does not need an async chain) + + assertBehavesSameVariations(3, + () -> { + // plain (unaffected) invocations... + plain(1); + sync(2); + }, + (callback) -> { + beginAsync().thenRun(c -> { + // ...are preserved above affected methods + plain(1); + async(2, c); + }).finish(callback); + }); + + assertBehavesSameVariations(3, + () -> { + // when a plain invocation follows an affected method... + sync(1); + plain(2); + }, + (callback) -> { + // ...it is moved to its own block, and must be completed: + beginAsync().thenRun(c -> { + async(1, c); + }).thenRun(c -> { + plain(2); + c.complete(c); + }).finish(callback); + }); + + assertBehavesSameVariations(3, + () -> { + // when an affected method follows an affected method + sync(1); + sync(2); + }, + (callback) -> { + // ...it is moved to its own block + beginAsync().thenRun(c -> { + async(1, c); + }).thenRun(c -> { + async(2, c); + }).finish(callback); + }); + } + + @Test + void testBasicVariations4() { + // tests the sync-sync pair with preceding and ensuing plain methods: + assertBehavesSameVariations(5, + () -> { + plain(11); + sync(1); + plain(22); + sync(2); + }, + (callback) -> { + beginAsync().thenRun(c -> { + plain(11); + async(1, c); + }).thenRun(c -> { + plain(22); + async(2, c); + }).finish(callback); + }); + + assertBehavesSameVariations(5, + () -> { + sync(1); + plain(11); + sync(2); + plain(22); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).thenRun(c -> { + plain(11); + async(2, c); + }).thenRunAndFinish(() ->{ + plain(22); + }, callback); + }); + } + + @Test + void testSupply() { + assertBehavesSameVariations(4, + () -> { + sync(0); + plain(1); + return syncReturns(2); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(0, c); + }).thenSupply(c -> { + plain(1); + asyncReturns(2, c); + }).finish(callback); + }); + } + + @Test + void testSupplyMixed() { + assertBehavesSameVariations(5, + () -> { + if (plainTest(1)) { + return syncReturns(11); + } else { + return plainReturns(22); + } + }, + (callback) -> { + beginAsync().thenSupply(c -> { + if (plainTest(1)) { + asyncReturns(11, c); + } else { + int r = plainReturns(22); + c.complete(r); // corresponds to a return, and + // must be followed by a return or end of method + } + }).finish(callback); + }); + } + + @SuppressWarnings("ConstantConditions") + @Test + void testFullChain() { + // tests a chain with: runnable, producer, function, function, consumer + assertBehavesSameVariations(14, + () -> { + plain(90); + sync(0); + plain(91); + sync(1); + plain(92); + int v = syncReturns(2); + plain(93); + v = syncReturns(v + 1); + plain(94); + v = syncReturns(v + 10); + plain(95); + sync(v + 100); + plain(96); + }, + (callback) -> { + beginAsync().thenRun(c -> { + plain(90); + async(0, c); + }).thenRun(c -> { + plain(91); + async(1, c); + }).thenSupply(c -> { + plain(92); + asyncReturns(2, c); + }).thenApply((v, c) -> { + plain(93); + asyncReturns(v + 1, c); + }).thenApply((v, c) -> { + plain(94); + asyncReturns(v + 10, c); + }).thenConsume((v, c) -> { + plain(95); + async(v + 100, c); + }).thenRunAndFinish(() -> { + plain(96); + }, callback); + }); + } + + @Test + void testConditionalVariations() { + assertBehavesSameVariations(5, + () -> { + if (plainTest(1)) { + sync(2); + } else { + sync(3); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + if (plainTest(1)) { + async(2, c); + } else { + async(3, c); + } + }).finish(callback); + }); + + // 2 : fail on first sync, fail on test + // 3 : true test, sync2, sync3 + // 2 : false test, sync3 + // 7 total + assertBehavesSameVariations(7, + () -> { + sync(0); + if (plainTest(1)) { + sync(2); + } + sync(3); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(0, c); + }).thenRunIf(() -> plainTest(1), c -> { + async(2, c); + }).thenRun(c -> { + async(3, c); + }).finish(callback); + }); + + // an additional affected method within the "if" branch + assertBehavesSameVariations(8, + () -> { + sync(0); + if (plainTest(1)) { + sync(21); + sync(22); + } + sync(3); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(0, c); + }).thenRunIf(() -> plainTest(1), + beginAsync().thenRun(c -> { + async(21, c); + }).thenRun((c) -> { + async(22, c); + }) + ).thenRun(c -> { + async(3, c); + }).finish(callback); + }); + } + + @Test + void testMixedConditionalCascade() { + assertBehavesSameVariations(9, + () -> { + boolean test1 = plainTest(1); + if (test1) { + return syncReturns(11); + } + boolean test2 = plainTest(2); + if (test2) { + return 22; + } + int x = syncReturns(33); + plain(x + 100); + return syncReturns(44); + }, + (callback) -> { + beginAsync().thenSupply(c -> { + boolean test1 = plainTest(1); + if (test1) { + asyncReturns(11, c); + return; + } + boolean test2 = plainTest(2); + if (test2) { + c.complete(22); + return; + } + beginAsync().thenSupply(c2 -> { + asyncReturns(33, c2); + }).thenApply((x, c2) -> { + plain(assertNotNull(x) + 100); + asyncReturns(44, c2); + }).finish(c); + }).finish(callback); + }); + } + + @Test + void testPlain() { + // for completeness; should not be used, since there is no async + assertBehavesSameVariations(2, + () -> { + plain(1); + }, + (callback) -> { + beginAsync().thenRun(c -> { + plain(1); + c.complete(c); + }).finish(callback); + }); + } + + @Test + void testTryCatch() { + // single method in both try and catch + assertBehavesSameVariations(3, + () -> { + try { + sync(1); + } catch (Throwable t) { + sync(2); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).onErrorIf(t -> true, (t, c) -> { + async(2, c); + }).finish(callback); + }); + + // mixed sync/plain + assertBehavesSameVariations(3, + () -> { + try { + sync(1); + } catch (Throwable t) { + plain(2); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).onErrorIf(t -> true, (t, c) -> { + plain(2); + c.complete(c); + }).finish(callback); + }); + + // chain of 2 in try + // "onErrorIf" will consider everything in + // the preceding chain to be part of the try + assertBehavesSameVariations(5, + () -> { + try { + sync(1); + sync(2); + } catch (Throwable t) { + sync(9); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).thenRun(c -> { + async(2, c); + }).onErrorIf(t -> true, (t, c) -> { + async(9, c); + }).finish(callback); + }); + + // chain of 2 in catch + assertBehavesSameVariations(4, + () -> { + try { + sync(1); + } catch (Throwable t) { + sync(8); + sync(9); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).onErrorIf(t -> true, (t, callback2) -> { + beginAsync().thenRun(c -> { + async(8, c); + }).thenRun(c -> { + async(9, c); + }).finish(callback2); + }).finish(callback); + }); + + // method after the try-catch block + // here, the try-catch must be nested (as a code block) + assertBehavesSameVariations(5, + () -> { + try { + sync(1); + } catch (Throwable t) { + sync(2); + } + sync(3); + }, + (callback) -> { + beginAsync().thenRun(c2 -> { + beginAsync().thenRun(c -> { + async(1, c); + }).onErrorIf(t -> true, (t, c) -> { + async(2, c); + }).finish(c2); + }).thenRun(c -> { + async(3, c); + }).finish(callback); + }); + + // multiple catch blocks + // WARNING: these are not exclusive; if multiple "onErrorIf" blocks + // match, they will all be executed. + assertBehavesSameVariations(5, + () -> { + try { + if (plainTest(1)) { + throw new UnsupportedOperationException("A"); + } else { + throw new IllegalStateException("B"); + } + } catch (UnsupportedOperationException t) { + sync(8); + } catch (IllegalStateException t) { + sync(9); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + if (plainTest(1)) { + throw new UnsupportedOperationException("A"); + } else { + throw new IllegalStateException("B"); + } + }).onErrorIf(t -> t instanceof UnsupportedOperationException, (t, c) -> { + async(8, c); + }).onErrorIf(t -> t instanceof IllegalStateException, (t, c) -> { + async(9, c); + }).finish(callback); + }); + } + + @Test + void testTryCatchWithVariables() { + // using supply etc. + assertBehavesSameVariations(12, + () -> { + try { + int i = plainTest(0) ? 1 : 2; + i = syncReturns(i + 10); + sync(i + 100); + } catch (Throwable t) { + sync(3); + } + }, + (callback) -> { + beginAsync().thenRun( + beginAsync().thenSupply(c -> { + int i = plainTest(0) ? 1 : 2; + asyncReturns(i + 10, c); + }).thenConsume((i, c) -> { + async(assertNotNull(i) + 100, c); + }) + ).onErrorIf(t -> true, (t, c) -> { + async(3, c); + }).finish(callback); + }); + + // using an externally-declared variable + assertBehavesSameVariations(17, + () -> { + int i = plainTest(0) ? 1 : 2; + try { + i = syncReturns(i + 10); + sync(i + 100); + } catch (Throwable t) { + sync(3); + } + sync(i + 1000); + }, + (callback) -> { + final int[] i = new int[1]; + beginAsync().thenRun(c -> { + i[0] = plainTest(0) ? 1 : 2; + c.complete(c); + }).thenRun(c -> { + beginAsync().thenSupply(c2 -> { + asyncReturns(i[0] + 10, c2); + }).thenConsume((i2, c2) -> { + i[0] = assertNotNull(i2); + async(i2 + 100, c2); + }).onErrorIf(t -> true, (t, c2) -> { + async(3, c2); + }).finish(c); + }).thenRun(c -> { + async(i[0] + 1000, c); + }).finish(callback); + }); + } + + @Test + void testTryCatchWithConditionInCatch() { + assertBehavesSameVariations(12, + () -> { + try { + sync(plainTest(0) ? 1 : 2); + sync(3); + } catch (Throwable t) { + sync(5); + if (t.getMessage().equals("exception-1")) { + throw t; + } else { + throw new RuntimeException("wrapped-" + t.getMessage(), t); + } + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(plainTest(0) ? 1 : 2, c); + }).thenRun(c -> { + async(3, c); + }).onErrorIf(t -> true, (t, c) -> { + beginAsync().thenRun(c2 -> { + async(5, c2); + }).thenRun(c2 -> { + if (assertNotNull(t).getMessage().equals("exception-1")) { + throw (RuntimeException) t; + } else { + throw new RuntimeException("wrapped-" + t.getMessage(), t); + } + }).finish(c); + }).finish(callback); + }); + } + + @Test + void testTryCatchTestAndRethrow() { + // thenSupply: + assertBehavesSameVariations(5, + () -> { + try { + return syncReturns(1); + } catch (Exception e) { + if (e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1")) { + return syncReturns(2); + } else { + throw e; + } + } + }, + (callback) -> { + beginAsync().thenSupply(c -> { + asyncReturns(1, c); + }).onErrorIf(e -> e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1"), (t, c) -> { + asyncReturns(2, c); + }).finish(callback); + }); + + // thenRun: + assertBehavesSameVariations(5, + () -> { + try { + sync(1); + } catch (Exception e) { + if (e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1")) { + sync(2); + } else { + throw e; + } + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).onErrorIf(e -> e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1"), (t, c) -> { + async(2, c); + }).finish(callback); + }); + } + + @Test + void testLoop() { + assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1, + () -> { + while (true) { + try { + sync(plainTest(0) ? 1 : 2); + } catch (RuntimeException e) { + if (e.getMessage().equals("exception-1")) { + continue; + } + throw e; + } + break; + } + }, + (callback) -> { + beginAsync().thenRunRetryingWhile( + c -> sync(plainTest(0) ? 1 : 2), + e -> e.getMessage().equals("exception-1") + ).finish(callback); + }); + } + + @Test + void testFinally() { + // (in try: normal flow + exception + exception) * (in finally: normal + exception) = 6 + assertBehavesSameVariations(6, + () -> { + try { + plain(1); + sync(2); + } finally { + plain(3); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + plain(1); + async(2, c); + }).thenAlwaysRunAndFinish(() -> { + plain(3); + }, callback); + }); + } + + @Test + void testUsedAsLambda() { + assertBehavesSameVariations(4, + () -> { + Supplier s = () -> syncReturns(9); + sync(0); + plain(1); + return s.get(); + }, + (callback) -> { + AsyncSupplier s = (c) -> asyncReturns(9, c); + beginAsync().thenRun(c -> { + async(0, c); + }).thenSupply((c) -> { + plain(1); + s.getAsync(c); + }).finish(callback); + }); + } + + @Test + void testVariables() { + assertBehavesSameVariations(3, + () -> { + int something; + something = 90; + sync(something); + something = something + 10; + sync(something); + }, + (callback) -> { + // Certain variables may need to be shared; these can be + // declared (but not initialized) outside the async chain. + // Any container works (atomic allowed but not needed) + final int[] something = new int[1]; + beginAsync().thenRun(c -> { + something[0] = 90; + async(something[0], c); + }).thenRun((c) -> { + something[0] = something[0] + 10; + async(something[0], c); + }).finish(callback); + }); + } + + @Test + void testInvalid() { + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + throw new IllegalStateException("must not cause second callback invocation"); + }).finish((v, e) -> {}); + }); + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + }).finish((v, e) -> { + throw new IllegalStateException("must not cause second callback invocation"); + }); + }); + } + + @Test + void testDerivation() { + // Demonstrates the progression from nested async to the API. + + // Stand-ins for sync-async methods; these "happily" do not throw + // exceptions, to avoid complicating this demo async code. + Consumer happySync = (i) -> { + invocationTracker.getNextOption(1); + listener.add("affected-success-" + i); + }; + BiConsumer> happyAsync = (i, c) -> { + happySync.accept(i); + c.complete(c); + }; + + // Standard nested async, no error handling: + assertBehavesSameVariations(1, + () -> { + happySync.accept(1); + happySync.accept(2); + }, + (callback) -> { + happyAsync.accept(1, (v, e) -> { + happyAsync.accept(2, callback); + }); + }); + + // When both methods are naively extracted, they are out of order: + assertBehavesSameVariations(1, + () -> { + happySync.accept(1); + happySync.accept(2); + }, + (callback) -> { + SingleResultCallback second = (v, e) -> { + happyAsync.accept(2, callback); + }; + SingleResultCallback first = (v, e) -> { + happyAsync.accept(1, second); + }; + first.onResult(null, null); + }); + + // We create an "AsyncRunnable" that takes a callback, which + // decouples any async methods from each other, allowing them + // to be declared in a sync-like order, and without nesting: + assertBehavesSameVariations(1, + () -> { + happySync.accept(1); + happySync.accept(2); + }, + (callback) -> { + AsyncRunnable first = (SingleResultCallback c) -> { + happyAsync.accept(1, c); + }; + AsyncRunnable second = (SingleResultCallback c) -> { + happyAsync.accept(2, c); + }; + // This is a simplified variant of the "then" methods; + // it has no error handling. It takes methods A and B, + // and returns C, which is B(A()). + AsyncRunnable combined = (c) -> { + first.unsafeFinish((r, e) -> { + second.unsafeFinish(c); + }); + }; + combined.unsafeFinish(callback); + }); + + // This combining method is added as a default method on AsyncRunnable, + // and a "finish" method wraps the resulting methods. This also adds + // exception handling and monadic short-circuiting of ensuing methods + // when an exception arises (comparable to how thrown exceptions "skip" + // ensuing code). + assertBehavesSameVariations(3, + () -> { + sync(1); + sync(2); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).thenRun(c -> { + async(2, c); + }).finish(callback); + }); + } + + // invoked methods: + + private void plain(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else { + listener.add("plain-success-" + i); + } + } + + private int plainReturns(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else { + listener.add("plain-success-" + i); + return i; + } + } + + private boolean plainTest(final int i) { + int cur = invocationTracker.getNextOption(3); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else if (cur == 1) { + listener.add("plain-false-" + i); + return false; + } else { + listener.add("plain-true-" + i); + return true; + } + } + + private void sync(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("affected-exception-" + i); + throw new RuntimeException("exception-" + i); + } else { + listener.add("affected-success-" + i); + } + } + + private Integer syncReturns(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("affected-exception-" + i); + throw new RuntimeException("exception-" + i); + } else { + listener.add("affected-success-" + i); + return i; + } + } + + private void async(final int i, final SingleResultCallback callback) { + if (throwExceptionsFromAsync) { + sync(i); + callback.complete(callback); + + } else { + try { + sync(i); + callback.complete(callback); + } catch (Throwable t) { + callback.onResult(null, t); + } + } + } + + private void asyncReturns(final int i, final SingleResultCallback callback) { + if (throwExceptionsFromAsync) { + callback.complete(syncReturns(i)); + } else { + try { + callback.complete(syncReturns(i)); + } catch (Throwable t) { + callback.onResult(null, t); + } + } + } + + // assert methods: + + private void assertBehavesSameVariations(final int expectedVariations, final Runnable sync, + final Consumer> async) { + assertBehavesSameVariations(expectedVariations, + () -> { + sync.run(); + return null; + }, + (c) -> { + async.accept((v, e) -> c.onResult(v, e)); + }); + } + + private void assertBehavesSameVariations(final int expectedVariations, final Supplier sync, + final Consumer> async) { + // run the variation-trying code twice, with direct/indirect exceptions + for (int i = 0; i < 2; i++) { + throwExceptionsFromAsync = i == 0; + + // the variation-trying code: + invocationTracker.reset(); + do { + invocationTracker.startInitialStep(); + assertBehavesSame( + sync, + () -> invocationTracker.startMatchStep(), + async); + } while (invocationTracker.countDown()); + assertEquals(expectedVariations, invocationTracker.getVariationCount(), + "number of variations did not match"); + } + + } + + private void assertBehavesSame(final Supplier sync, final Runnable between, final Consumer> async) { + + T expectedValue = null; + Throwable expectedException = null; + try { + expectedValue = sync.get(); + } catch (Throwable e) { + expectedException = e; + } + List expectedEvents = listener.getEventStrings(); + + listener.clear(); + between.run(); + + AtomicReference actualValue = new AtomicReference<>(); + AtomicReference actualException = new AtomicReference<>(); + AtomicBoolean wasCalled = new AtomicBoolean(false); + try { + async.accept((v, e) -> { + actualValue.set(v); + actualException.set(e); + wasCalled.set(true); + }); + } catch (Throwable e) { + fail("async threw instead of using callback"); + } + + // The following code can be used to debug variations: +// System.out.println("===START"); +// System.out.println("sync: " + expectedEvents); +// System.out.println("callback called?: " + wasCalled.get()); +// System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get()); +// System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get()); +// System.out.println("variant: " + (throwExceptionsFromAsync +// ? "exceptions thrown directly" : "exceptions into callbacks")); +// System.out.println("===END"); + + assertTrue(wasCalled.get(), "callback should have been called"); + assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched"); + assertEquals(expectedValue, actualValue.get()); + assertEquals(expectedException == null, actualException.get() == null, + "both or neither should have produced an exception"); + if (expectedException != null) { + assertEquals(expectedException.getMessage(), actualException.get().getMessage()); + assertEquals(expectedException.getClass(), actualException.get().getClass()); + } + + listener.clear(); + } + + /** + * Tracks invocations: allows testing of all variations of a method calls + */ + private static class InvocationTracker { + public static final int DEPTH_LIMIT = 50; + private final List invocationResults = new ArrayList<>(); + private boolean isMatchStep = false; // vs initial step + private int item = 0; + private int variationCount = 0; + + public void reset() { + variationCount = 0; + } + + public void startInitialStep() { + variationCount++; + isMatchStep = false; + item = -1; + } + + public int getNextOption(final int myOptionsSize) { + item++; + if (item >= invocationResults.size()) { + if (isMatchStep) { + fail("result should have been pre-initialized: steps may not match"); + } + if (isWithinDepthLimit()) { + invocationResults.add(myOptionsSize - 1); + } else { + invocationResults.add(0); // choose "0" option, usually an exception + } + } + return invocationResults.get(item); + } + + public void startMatchStep() { + isMatchStep = true; + item = -1; + } + + private boolean countDown() { + while (!invocationResults.isEmpty()) { + int lastItemIndex = invocationResults.size() - 1; + int lastItem = invocationResults.get(lastItemIndex); + if (lastItem > 0) { + // count current digit down by 1, until 0 + invocationResults.set(lastItemIndex, lastItem - 1); + return true; + } else { + // current digit completed, remove (move left) + invocationResults.remove(lastItemIndex); + } + } + return false; + } + + public int getVariationCount() { + return variationCount; + } + + public boolean isWithinDepthLimit() { + return invocationResults.size() < DEPTH_LIMIT; + } + } +} diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java index 46fa37bf8d2..a1c127439c8 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java @@ -22,8 +22,8 @@ import com.mongodb.ServerApi; import com.mongodb.connection.ClusterType; import com.mongodb.connection.ServerDescription; +import com.mongodb.internal.async.AsyncSupplier; import com.mongodb.internal.async.SingleResultCallback; -import com.mongodb.internal.async.function.AsyncCallbackSupplier; import com.mongodb.internal.binding.AbstractReferenceCounted; import com.mongodb.internal.binding.AsyncClusterAwareReadWriteBinding; import com.mongodb.internal.binding.AsyncConnectionSource; @@ -41,6 +41,7 @@ import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.connection.ClusterType.LOAD_BALANCED; import static com.mongodb.connection.ClusterType.SHARDED; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; /** *

This class is not part of the public API and may be removed or changed at any time

@@ -92,40 +93,38 @@ public void getReadConnectionSource(final SingleResultCallback callback) { - getConnectionSource(wrappedConnectionSourceCallback -> - wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, wrappedConnectionSourceCallback), - callback); + AsyncSupplier supplier = callback2 -> + wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, callback2); + getConnectionSource(supplier, callback); } public void getWriteConnectionSource(final SingleResultCallback callback) { getConnectionSource(wrapped::getWriteConnectionSource, callback); } - private void getConnectionSource(final AsyncCallbackSupplier connectionSourceSupplier, + private void getConnectionSource(final AsyncSupplier connectionSourceSupplier, final SingleResultCallback callback) { - WrappingCallback wrappingCallback = new WrappingCallback(callback); - - if (!session.hasActiveTransaction()) { - connectionSourceSupplier.get(wrappingCallback); - return; - } - if (TransactionContext.get(session) == null) { - connectionSourceSupplier.get((source, t) -> { - if (t != null) { - wrappingCallback.onResult(null, t); - } else { + beginAsync().thenSupply(c -> { + if (!session.hasActiveTransaction()) { + connectionSourceSupplier.getAsync(c); + } else if (TransactionContext.get(session) != null) { + wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress()), c); + } else { + beginAsync().thenSupply(c2 -> { + connectionSourceSupplier.getAsync(c2); + }).thenApply((source, c2) -> { ClusterType clusterType = assertNotNull(source).getServerDescription().getClusterType(); if (clusterType == SHARDED || clusterType == LOAD_BALANCED) { TransactionContext transactionContext = new TransactionContext<>(clusterType); session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext); transactionContext.release(); // The session is responsible for retaining a reference to the context } - wrappingCallback.onResult(source, null); - } - }); - } else { - wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress()), wrappingCallback); - } + c2.complete(source); + }).finish(c); + } + }).thenApply((source, c) -> { + c.complete(new SessionBindingAsyncConnectionSource(source)); + }).finish(callback); } @Override @@ -187,24 +186,24 @@ public ReadPreference getReadPreference() { @Override public void getConnection(final SingleResultCallback callback) { - TransactionContext transactionContext = TransactionContext.get(session); - if (transactionContext != null && transactionContext.isConnectionPinningRequired()) { - AsyncConnection pinnedConnection = transactionContext.getPinnedConnection(); - if (pinnedConnection == null) { - wrapped.getConnection((connection, t) -> { - if (t != null) { - callback.onResult(null, t); - } else { - transactionContext.pinConnection(assertNotNull(connection), AsyncConnection::markAsPinned); - callback.onResult(connection, null); - } - }); - } else { - callback.onResult(pinnedConnection.retain(), null); + beginAsync().thenSupply(c -> { + TransactionContext transactionContext = TransactionContext.get(session); + if (transactionContext == null || !transactionContext.isConnectionPinningRequired()) { + wrapped.getConnection(c); + return; } - } else { - wrapped.getConnection(callback); - } + AsyncConnection pinnedAsyncConnection = transactionContext.getPinnedConnection(); + if (pinnedAsyncConnection != null) { + c.complete(pinnedAsyncConnection.retain()); + return; + } + beginAsync().thenSupply(c2 -> { + wrapped.getConnection(c2); + }).thenApply((connection, c2) -> { + transactionContext.pinConnection(connection, AsyncConnection::markAsPinned); + c2.complete(connection); + }).finish(c); + }).finish(callback); } @Override @@ -281,21 +280,4 @@ public ReadConcern getReadConcern() { } } } - - private class WrappingCallback implements SingleResultCallback { - private final SingleResultCallback callback; - - WrappingCallback(final SingleResultCallback callback) { - this.callback = callback; - } - - @Override - public void onResult(@Nullable final AsyncConnectionSource result, @Nullable final Throwable t) { - if (t != null) { - callback.onResult(null, t); - } else { - callback.onResult(new SessionBindingAsyncConnectionSource(assertNotNull(result)), null); - } - } - } } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java index b01b63d4a64..340da88e67f 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java @@ -85,7 +85,7 @@ public void failed(final Throwable t) { stream.close(); sink.error(t); } - }); + }.asCallback()); }).onErrorMap(this::unWrapException); } diff --git a/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java b/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java index a265ca01a7d..faa476b8262 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java +++ b/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java @@ -89,17 +89,18 @@ public int release() { @Override public ConnectionSource getReadConnectionSource() { - return new SessionBindingConnectionSource(getConnectionSource(wrapped::getReadConnectionSource)); + return getConnectionSource(wrapped::getReadConnectionSource); } @Override public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference) { - return new SessionBindingConnectionSource(getConnectionSource(() -> - wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference))); + Supplier supplier = () -> + wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference); + return getConnectionSource(supplier); } public ConnectionSource getWriteConnectionSource() { - return new SessionBindingConnectionSource(getConnectionSource(wrapped::getWriteConnectionSource)); + return getConnectionSource(wrapped::getWriteConnectionSource); } @Override @@ -123,23 +124,22 @@ public OperationContext getOperationContext() { return wrapped.getOperationContext(); } - private ConnectionSource getConnectionSource(final Supplier wrappedConnectionSourceSupplier) { + private ConnectionSource getConnectionSource(final Supplier connectionSourceSupplier) { + ConnectionSource source; if (!session.hasActiveTransaction()) { - return wrappedConnectionSourceSupplier.get(); - } - - if (TransactionContext.get(session) == null) { - ConnectionSource source = wrappedConnectionSourceSupplier.get(); + source = connectionSourceSupplier.get(); + } else if (TransactionContext.get(session) != null) { + source = wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress())); + } else { + source = connectionSourceSupplier.get(); ClusterType clusterType = source.getServerDescription().getClusterType(); if (clusterType == SHARDED || clusterType == LOAD_BALANCED) { TransactionContext transactionContext = new TransactionContext<>(clusterType); session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext); transactionContext.release(); // The session is responsible for retaining a reference to the context } - return source; - } else { - return wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress())); } + return new SessionBindingConnectionSource(source); } private class SessionBindingConnectionSource implements ConnectionSource { @@ -183,18 +183,16 @@ public ReadPreference getReadPreference() { @Override public Connection getConnection() { TransactionContext transactionContext = TransactionContext.get(session); - if (transactionContext != null && transactionContext.isConnectionPinningRequired()) { - Connection pinnedConnection = transactionContext.getPinnedConnection(); - if (pinnedConnection == null) { - Connection connection = wrapped.getConnection(); - transactionContext.pinConnection(connection, Connection::markAsPinned); - return connection; - } else { - return pinnedConnection.retain(); - } - } else { + if (transactionContext == null || !transactionContext.isConnectionPinningRequired()) { return wrapped.getConnection(); } + Connection pinnedConnection = transactionContext.getPinnedConnection(); + if (pinnedConnection != null) { + return pinnedConnection.retain(); + } + Connection connection = wrapped.getConnection(); + transactionContext.pinConnection(connection, Connection::markAsPinned); + return connection; } @Override