|
22 | 22 | import com.mongodb.ServerApi;
|
23 | 23 | import com.mongodb.connection.ClusterType;
|
24 | 24 | import com.mongodb.connection.ServerDescription;
|
| 25 | +import com.mongodb.internal.async.AsyncSupplier; |
25 | 26 | import com.mongodb.internal.async.SingleResultCallback;
|
26 |
| -import com.mongodb.internal.async.function.AsyncCallbackSupplier; |
27 | 27 | import com.mongodb.internal.binding.AbstractReferenceCounted;
|
28 | 28 | import com.mongodb.internal.binding.AsyncClusterAwareReadWriteBinding;
|
29 | 29 | import com.mongodb.internal.binding.AsyncConnectionSource;
|
30 | 30 | import com.mongodb.internal.binding.AsyncReadWriteBinding;
|
31 | 31 | import com.mongodb.internal.binding.TransactionContext;
|
32 | 32 | import com.mongodb.internal.connection.AsyncConnection;
|
| 33 | +import com.mongodb.internal.connection.Connection; |
33 | 34 | import com.mongodb.internal.connection.OperationContext;
|
34 | 35 | import com.mongodb.internal.session.ClientSessionContext;
|
35 | 36 | import com.mongodb.internal.session.SessionContext;
|
|
41 | 42 | import static com.mongodb.assertions.Assertions.notNull;
|
42 | 43 | import static com.mongodb.connection.ClusterType.LOAD_BALANCED;
|
43 | 44 | import static com.mongodb.connection.ClusterType.SHARDED;
|
| 45 | +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; |
44 | 46 |
|
45 | 47 | /**
|
46 | 48 | * <p>This class is not part of the public API and may be removed or changed at any time</p>
|
@@ -92,40 +94,41 @@ public void getReadConnectionSource(final SingleResultCallback<AsyncConnectionSo
|
92 | 94 | @Override
|
93 | 95 | public void getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference,
|
94 | 96 | final SingleResultCallback<AsyncConnectionSource> callback) {
|
95 |
| - getConnectionSource(wrappedConnectionSourceCallback -> |
96 |
| - wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, wrappedConnectionSourceCallback), |
97 |
| - callback); |
| 97 | + AsyncSupplier<AsyncConnectionSource> supplier = callback2 -> |
| 98 | + wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, callback2); |
| 99 | + getConnectionSource(supplier, callback); |
98 | 100 | }
|
99 | 101 |
|
100 | 102 | public void getWriteConnectionSource(final SingleResultCallback<AsyncConnectionSource> callback) {
|
101 | 103 | getConnectionSource(wrapped::getWriteConnectionSource, callback);
|
102 | 104 | }
|
103 | 105 |
|
104 |
| - private void getConnectionSource(final AsyncCallbackSupplier<AsyncConnectionSource> connectionSourceSupplier, |
| 106 | + private void getConnectionSource(final AsyncSupplier<AsyncConnectionSource> connectionSourceSupplier, |
105 | 107 | final SingleResultCallback<AsyncConnectionSource> callback) {
|
106 |
| - WrappingCallback wrappingCallback = new WrappingCallback(callback); |
107 |
| - |
108 |
| - if (!session.hasActiveTransaction()) { |
109 |
| - connectionSourceSupplier.get(wrappingCallback); |
110 |
| - return; |
111 |
| - } |
112 |
| - if (TransactionContext.get(session) == null) { |
113 |
| - connectionSourceSupplier.get((source, t) -> { |
114 |
| - if (t != null) { |
115 |
| - wrappingCallback.onResult(null, t); |
116 |
| - } else { |
117 |
| - ClusterType clusterType = assertNotNull(source).getServerDescription().getClusterType(); |
118 |
| - if (clusterType == SHARDED || clusterType == LOAD_BALANCED) { |
119 |
| - TransactionContext<AsyncConnection> transactionContext = new TransactionContext<>(clusterType); |
120 |
| - session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext); |
121 |
| - transactionContext.release(); // The session is responsible for retaining a reference to the context |
122 |
| - } |
123 |
| - wrappingCallback.onResult(source, null); |
124 |
| - } |
125 |
| - }); |
126 |
| - } else { |
127 |
| - wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress()), wrappingCallback); |
128 |
| - } |
| 108 | + // wrapper applied at end |
| 109 | + beginAsync().<AsyncConnectionSource>thenSupply(c -> { |
| 110 | + if (!session.hasActiveTransaction()) { |
| 111 | + connectionSourceSupplier.getAsync(c); |
| 112 | + return; |
| 113 | + } |
| 114 | + if (TransactionContext.get(session) != null) { |
| 115 | + wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress()), c); |
| 116 | + return; |
| 117 | + } |
| 118 | + beginAsync().<AsyncConnectionSource>thenSupply(c2 -> { |
| 119 | + connectionSourceSupplier.getAsync(c2); |
| 120 | + }).<AsyncConnectionSource>thenApply((source, c2) -> { |
| 121 | + ClusterType clusterType = assertNotNull(source).getServerDescription().getClusterType(); |
| 122 | + if (clusterType == SHARDED || clusterType == LOAD_BALANCED) { |
| 123 | + TransactionContext<AsyncConnection> transactionContext = new TransactionContext<>(clusterType); |
| 124 | + session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext); |
| 125 | + transactionContext.release(); // The session is responsible for retaining a reference to the context |
| 126 | + } // |
| 127 | + c2.complete(source); |
| 128 | + }).finish(c); |
| 129 | + }).<AsyncConnectionSource>thenApply((source, c) -> { |
| 130 | + c.complete(new SessionBindingAsyncConnectionSource(assertNotNull(source))); |
| 131 | + }).finish(callback); |
129 | 132 | }
|
130 | 133 |
|
131 | 134 | @Override
|
@@ -187,24 +190,24 @@ public ReadPreference getReadPreference() {
|
187 | 190 |
|
188 | 191 | @Override
|
189 | 192 | public void getConnection(final SingleResultCallback<AsyncConnection> callback) {
|
190 |
| - TransactionContext<AsyncConnection> transactionContext = TransactionContext.get(session); |
191 |
| - if (transactionContext != null && transactionContext.isConnectionPinningRequired()) { |
192 |
| - AsyncConnection pinnedConnection = transactionContext.getPinnedConnection(); |
193 |
| - if (pinnedConnection == null) { |
194 |
| - wrapped.getConnection((connection, t) -> { |
195 |
| - if (t != null) { |
196 |
| - callback.onResult(null, t); |
197 |
| - } else { |
198 |
| - transactionContext.pinConnection(assertNotNull(connection), AsyncConnection::markAsPinned); |
199 |
| - callback.onResult(connection, null); |
200 |
| - } |
201 |
| - }); |
202 |
| - } else { |
203 |
| - callback.onResult(pinnedConnection.retain(), null); |
204 |
| - } |
205 |
| - } else { |
206 |
| - wrapped.getConnection(callback); |
207 |
| - } |
| 193 | + beginAsync().<AsyncConnection>thenSupply(c -> { |
| 194 | + TransactionContext<AsyncConnection> transactionContext = TransactionContext.get(session); |
| 195 | + if (transactionContext == null || !transactionContext.isConnectionPinningRequired()) { |
| 196 | + wrapped.getConnection(c); |
| 197 | + return; |
| 198 | + } // |
| 199 | + AsyncConnection pinnedAsyncConnection = transactionContext.getPinnedConnection(); |
| 200 | + if (pinnedAsyncConnection != null) { |
| 201 | + c.complete(pinnedAsyncConnection.retain()); |
| 202 | + return; |
| 203 | + } // |
| 204 | + beginAsync().<AsyncConnection>thenSupply(c2 -> { |
| 205 | + wrapped.getConnection(c2); |
| 206 | + }).<AsyncConnection>thenApply((connection, c2) -> { |
| 207 | + transactionContext.pinConnection(assertNotNull(connection), AsyncConnection::markAsPinned); |
| 208 | + c2.complete(connection); |
| 209 | + }).finish(c); |
| 210 | + }).finish(callback); |
208 | 211 | }
|
209 | 212 |
|
210 | 213 | @Override
|
@@ -281,21 +284,4 @@ public ReadConcern getReadConcern() {
|
281 | 284 | }
|
282 | 285 | }
|
283 | 286 | }
|
284 |
| - |
285 |
| - private class WrappingCallback implements SingleResultCallback<AsyncConnectionSource> { |
286 |
| - private final SingleResultCallback<AsyncConnectionSource> callback; |
287 |
| - |
288 |
| - WrappingCallback(final SingleResultCallback<AsyncConnectionSource> callback) { |
289 |
| - this.callback = callback; |
290 |
| - } |
291 |
| - |
292 |
| - @Override |
293 |
| - public void onResult(@Nullable final AsyncConnectionSource result, @Nullable final Throwable t) { |
294 |
| - if (t != null) { |
295 |
| - callback.onResult(null, t); |
296 |
| - } else { |
297 |
| - callback.onResult(new SessionBindingAsyncConnectionSource(assertNotNull(result)), null); |
298 |
| - } |
299 |
| - } |
300 |
| - } |
301 | 287 | }
|
0 commit comments