Skip to content

Commit c53b9ce

Browse files
committed
Use async API in ClientSessionBinding.getConnection/Source
1 parent 35061d0 commit c53b9ce

File tree

2 files changed

+78
-91
lines changed

2 files changed

+78
-91
lines changed

driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java

+49-63
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@
2222
import com.mongodb.ServerApi;
2323
import com.mongodb.connection.ClusterType;
2424
import com.mongodb.connection.ServerDescription;
25+
import com.mongodb.internal.async.AsyncSupplier;
2526
import com.mongodb.internal.async.SingleResultCallback;
26-
import com.mongodb.internal.async.function.AsyncCallbackSupplier;
2727
import com.mongodb.internal.binding.AbstractReferenceCounted;
2828
import com.mongodb.internal.binding.AsyncClusterAwareReadWriteBinding;
2929
import com.mongodb.internal.binding.AsyncConnectionSource;
3030
import com.mongodb.internal.binding.AsyncReadWriteBinding;
3131
import com.mongodb.internal.binding.TransactionContext;
3232
import com.mongodb.internal.connection.AsyncConnection;
33+
import com.mongodb.internal.connection.Connection;
3334
import com.mongodb.internal.connection.OperationContext;
3435
import com.mongodb.internal.session.ClientSessionContext;
3536
import com.mongodb.internal.session.SessionContext;
@@ -41,6 +42,7 @@
4142
import static com.mongodb.assertions.Assertions.notNull;
4243
import static com.mongodb.connection.ClusterType.LOAD_BALANCED;
4344
import static com.mongodb.connection.ClusterType.SHARDED;
45+
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
4446

4547
/**
4648
* <p>This class is not part of the public API and may be removed or changed at any time</p>
@@ -92,40 +94,41 @@ public void getReadConnectionSource(final SingleResultCallback<AsyncConnectionSo
9294
@Override
9395
public void getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference,
9496
final SingleResultCallback<AsyncConnectionSource> callback) {
95-
getConnectionSource(wrappedConnectionSourceCallback ->
96-
wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, wrappedConnectionSourceCallback),
97-
callback);
97+
AsyncSupplier<AsyncConnectionSource> supplier = callback2 ->
98+
wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, callback2);
99+
getConnectionSource(supplier, callback);
98100
}
99101

100102
public void getWriteConnectionSource(final SingleResultCallback<AsyncConnectionSource> callback) {
101103
getConnectionSource(wrapped::getWriteConnectionSource, callback);
102104
}
103105

104-
private void getConnectionSource(final AsyncCallbackSupplier<AsyncConnectionSource> connectionSourceSupplier,
106+
private void getConnectionSource(final AsyncSupplier<AsyncConnectionSource> connectionSourceSupplier,
105107
final SingleResultCallback<AsyncConnectionSource> callback) {
106-
WrappingCallback wrappingCallback = new WrappingCallback(callback);
107-
108-
if (!session.hasActiveTransaction()) {
109-
connectionSourceSupplier.get(wrappingCallback);
110-
return;
111-
}
112-
if (TransactionContext.get(session) == null) {
113-
connectionSourceSupplier.get((source, t) -> {
114-
if (t != null) {
115-
wrappingCallback.onResult(null, t);
116-
} else {
117-
ClusterType clusterType = assertNotNull(source).getServerDescription().getClusterType();
118-
if (clusterType == SHARDED || clusterType == LOAD_BALANCED) {
119-
TransactionContext<AsyncConnection> transactionContext = new TransactionContext<>(clusterType);
120-
session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext);
121-
transactionContext.release(); // The session is responsible for retaining a reference to the context
122-
}
123-
wrappingCallback.onResult(source, null);
124-
}
125-
});
126-
} else {
127-
wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress()), wrappingCallback);
128-
}
108+
// wrapper applied at end
109+
beginAsync().<AsyncConnectionSource>thenSupply(c -> {
110+
if (!session.hasActiveTransaction()) {
111+
connectionSourceSupplier.getAsync(c);
112+
return;
113+
}
114+
if (TransactionContext.get(session) != null) {
115+
wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress()), c);
116+
return;
117+
}
118+
beginAsync().<AsyncConnectionSource>thenSupply(c2 -> {
119+
connectionSourceSupplier.getAsync(c2);
120+
}).<AsyncConnectionSource>thenApply((source, c2) -> {
121+
ClusterType clusterType = assertNotNull(source).getServerDescription().getClusterType();
122+
if (clusterType == SHARDED || clusterType == LOAD_BALANCED) {
123+
TransactionContext<AsyncConnection> transactionContext = new TransactionContext<>(clusterType);
124+
session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext);
125+
transactionContext.release(); // The session is responsible for retaining a reference to the context
126+
} //
127+
c2.complete(source);
128+
}).finish(c);
129+
}).<AsyncConnectionSource>thenApply((source, c) -> {
130+
c.complete(new SessionBindingAsyncConnectionSource(assertNotNull(source)));
131+
}).finish(callback);
129132
}
130133

131134
@Override
@@ -187,24 +190,24 @@ public ReadPreference getReadPreference() {
187190

188191
@Override
189192
public void getConnection(final SingleResultCallback<AsyncConnection> callback) {
190-
TransactionContext<AsyncConnection> transactionContext = TransactionContext.get(session);
191-
if (transactionContext != null && transactionContext.isConnectionPinningRequired()) {
192-
AsyncConnection pinnedConnection = transactionContext.getPinnedConnection();
193-
if (pinnedConnection == null) {
194-
wrapped.getConnection((connection, t) -> {
195-
if (t != null) {
196-
callback.onResult(null, t);
197-
} else {
198-
transactionContext.pinConnection(assertNotNull(connection), AsyncConnection::markAsPinned);
199-
callback.onResult(connection, null);
200-
}
201-
});
202-
} else {
203-
callback.onResult(pinnedConnection.retain(), null);
204-
}
205-
} else {
206-
wrapped.getConnection(callback);
207-
}
193+
beginAsync().<AsyncConnection>thenSupply(c -> {
194+
TransactionContext<AsyncConnection> transactionContext = TransactionContext.get(session);
195+
if (transactionContext == null || !transactionContext.isConnectionPinningRequired()) {
196+
wrapped.getConnection(c);
197+
return;
198+
} //
199+
AsyncConnection pinnedAsyncConnection = transactionContext.getPinnedConnection();
200+
if (pinnedAsyncConnection != null) {
201+
c.complete(pinnedAsyncConnection.retain());
202+
return;
203+
} //
204+
beginAsync().<AsyncConnection>thenSupply(c2 -> {
205+
wrapped.getConnection(c2);
206+
}).<AsyncConnection>thenApply((connection, c2) -> {
207+
transactionContext.pinConnection(assertNotNull(connection), AsyncConnection::markAsPinned);
208+
c2.complete(connection);
209+
}).finish(c);
210+
}).finish(callback);
208211
}
209212

210213
@Override
@@ -281,21 +284,4 @@ public ReadConcern getReadConcern() {
281284
}
282285
}
283286
}
284-
285-
private class WrappingCallback implements SingleResultCallback<AsyncConnectionSource> {
286-
private final SingleResultCallback<AsyncConnectionSource> callback;
287-
288-
WrappingCallback(final SingleResultCallback<AsyncConnectionSource> callback) {
289-
this.callback = callback;
290-
}
291-
292-
@Override
293-
public void onResult(@Nullable final AsyncConnectionSource result, @Nullable final Throwable t) {
294-
if (t != null) {
295-
callback.onResult(null, t);
296-
} else {
297-
callback.onResult(new SessionBindingAsyncConnectionSource(assertNotNull(result)), null);
298-
}
299-
}
300-
}
301287
}

driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java

+29-28
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import com.mongodb.internal.session.SessionContext;
3535
import com.mongodb.lang.Nullable;
3636

37+
import java.util.function.Function;
3738
import java.util.function.Supplier;
3839

3940
import static com.mongodb.connection.ClusterType.LOAD_BALANCED;
@@ -89,17 +90,18 @@ public int release() {
8990

9091
@Override
9192
public ConnectionSource getReadConnectionSource() {
92-
return new SessionBindingConnectionSource(getConnectionSource(wrapped::getReadConnectionSource));
93+
return getConnectionSource(wrapped::getReadConnectionSource);
9394
}
9495

9596
@Override
9697
public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference) {
97-
return new SessionBindingConnectionSource(getConnectionSource(() ->
98-
wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference)));
98+
Supplier<ConnectionSource> supplier = () ->
99+
wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference);
100+
return getConnectionSource(supplier);
99101
}
100102

101103
public ConnectionSource getWriteConnectionSource() {
102-
return new SessionBindingConnectionSource(getConnectionSource(wrapped::getWriteConnectionSource));
104+
return getConnectionSource(wrapped::getWriteConnectionSource);
103105
}
104106

105107
@Override
@@ -123,23 +125,24 @@ public OperationContext getOperationContext() {
123125
return wrapped.getOperationContext();
124126
}
125127

126-
private ConnectionSource getConnectionSource(final Supplier<ConnectionSource> wrappedConnectionSourceSupplier) {
128+
private ConnectionSource getConnectionSource(final Supplier<ConnectionSource> connectionSourceSupplier) {
129+
Function<ConnectionSource, ConnectionSource> wrapper = c -> new SessionBindingConnectionSource(c);
130+
127131
if (!session.hasActiveTransaction()) {
128-
return wrappedConnectionSourceSupplier.get();
132+
return wrapper.apply(connectionSourceSupplier.get());
129133
}
130-
131-
if (TransactionContext.get(session) == null) {
132-
ConnectionSource source = wrappedConnectionSourceSupplier.get();
133-
ClusterType clusterType = source.getServerDescription().getClusterType();
134-
if (clusterType == SHARDED || clusterType == LOAD_BALANCED) {
135-
TransactionContext<Connection> transactionContext = new TransactionContext<>(clusterType);
136-
session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext);
137-
transactionContext.release(); // The session is responsible for retaining a reference to the context
138-
}
139-
return source;
140-
} else {
141-
return wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress()));
134+
if (TransactionContext.get(session) != null) {
135+
return wrapper.apply(
136+
wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress())));
142137
}
138+
ConnectionSource source = connectionSourceSupplier.get();
139+
ClusterType clusterType = source.getServerDescription().getClusterType();
140+
if (clusterType == SHARDED || clusterType == LOAD_BALANCED) {
141+
TransactionContext<Connection> transactionContext = new TransactionContext<>(clusterType);
142+
session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext);
143+
transactionContext.release(); // The session is responsible for retaining a reference to the context
144+
}
145+
return wrapper.apply(source);
143146
}
144147

145148
private class SessionBindingConnectionSource implements ConnectionSource {
@@ -183,18 +186,16 @@ public ReadPreference getReadPreference() {
183186
@Override
184187
public Connection getConnection() {
185188
TransactionContext<Connection> transactionContext = TransactionContext.get(session);
186-
if (transactionContext != null && transactionContext.isConnectionPinningRequired()) {
187-
Connection pinnedConnection = transactionContext.getPinnedConnection();
188-
if (pinnedConnection == null) {
189-
Connection connection = wrapped.getConnection();
190-
transactionContext.pinConnection(connection, Connection::markAsPinned);
191-
return connection;
192-
} else {
193-
return pinnedConnection.retain();
194-
}
195-
} else {
189+
if (transactionContext == null || !transactionContext.isConnectionPinningRequired()) {
196190
return wrapped.getConnection();
197191
}
192+
Connection pinnedConnection = transactionContext.getPinnedConnection();
193+
if (pinnedConnection != null) {
194+
return pinnedConnection.retain();
195+
}
196+
Connection connection = wrapped.getConnection();
197+
transactionContext.pinConnection(connection, Connection::markAsPinned);
198+
return connection;
198199
}
199200

200201
@Override

0 commit comments

Comments
 (0)