Skip to content

Commit 2b47064

Browse files
committed
Implement OIDC auth for async (#1131)
JAVA-4981
1 parent 0b9b8ca commit 2b47064

File tree

10 files changed

+228
-80
lines changed

10 files changed

+228
-80
lines changed

driver-core/src/main/com/mongodb/assertions/Assertions.java

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package com.mongodb.assertions;
1919

20-
import com.mongodb.internal.async.SingleResultCallback;
2120
import com.mongodb.lang.Nullable;
2221

2322
import java.util.Collection;
@@ -79,25 +78,6 @@ public static <T> Iterable<T> notNullElements(final String name, final Iterable<
7978
return values;
8079
}
8180

82-
/**
83-
* Throw IllegalArgumentException if the value is null.
84-
*
85-
* @param name the parameter name
86-
* @param value the value that should not be null
87-
* @param callback the callback that also is passed the exception if the value is null
88-
* @param <T> the value type
89-
* @return the value
90-
* @throws java.lang.IllegalArgumentException if value is null
91-
*/
92-
public static <T> T notNull(final String name, final T value, final SingleResultCallback<?> callback) {
93-
if (value == null) {
94-
IllegalArgumentException exception = new IllegalArgumentException(name + " can not be null");
95-
callback.completeExceptionally(exception);
96-
throw exception;
97-
}
98-
return value;
99-
}
100-
10181
/**
10282
* Throw IllegalStateException if the condition if false.
10383
*
@@ -111,22 +91,6 @@ public static void isTrue(final String name, final boolean condition) {
11191
}
11292
}
11393

114-
/**
115-
* Throw IllegalStateException if the condition if false.
116-
*
117-
* @param name the name of the state that is being checked
118-
* @param condition the condition about the parameter to check
119-
* @param callback the callback that also is passed the exception if the condition is not true
120-
* @throws java.lang.IllegalStateException if the condition is false
121-
*/
122-
public static void isTrue(final String name, final boolean condition, final SingleResultCallback<?> callback) {
123-
if (!condition) {
124-
IllegalStateException exception = new IllegalStateException("state should be: " + name);
125-
callback.completeExceptionally(exception);
126-
throw exception;
127-
}
128-
}
129-
13094
/**
13195
* Throw IllegalArgumentException if the condition if false.
13296
*

driver-core/src/main/com/mongodb/internal/Locks.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package com.mongodb.internal;
1818

1919
import com.mongodb.MongoInterruptedException;
20+
import com.mongodb.internal.async.AsyncRunnable;
21+
import com.mongodb.internal.async.SingleResultCallback;
2022

2123
import java.util.concurrent.locks.Lock;
2224
import java.util.concurrent.locks.ReentrantLock;
@@ -36,7 +38,23 @@ public static void withLock(final Lock lock, final Runnable action) {
3638
});
3739
}
3840

39-
public static <V> V withLock(final StampedLock lock, final Supplier<V> supplier) {
41+
public static void withLockAsync(final StampedLock lock, final AsyncRunnable runnable,
42+
final SingleResultCallback<Void> callback) {
43+
long stamp;
44+
try {
45+
stamp = lock.writeLockInterruptibly();
46+
} catch (InterruptedException e) {
47+
Thread.currentThread().interrupt();
48+
callback.onResult(null, new MongoInterruptedException("Interrupted waiting for lock", e));
49+
return;
50+
}
51+
52+
runnable.thenAlwaysRunAndFinish(() -> {
53+
lock.unlockWrite(stamp);
54+
}, callback);
55+
}
56+
57+
public static void withLock(final StampedLock lock, final Runnable runnable) {
4058
long stamp;
4159
try {
4260
stamp = lock.writeLockInterruptibly();
@@ -45,7 +63,7 @@ public static <V> V withLock(final StampedLock lock, final Supplier<V> supplier)
4563
throw new MongoInterruptedException("Interrupted waiting for lock", e);
4664
}
4765
try {
48-
return supplier.get();
66+
runnable.run();
4967
} finally {
5068
lock.unlockWrite(stamp);
5169
}

driver-core/src/main/com/mongodb/internal/connection/Authenticator.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import com.mongodb.lang.Nullable;
2828

2929
import static com.mongodb.assertions.Assertions.notNull;
30+
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
3031

3132
/**
3233
* <p>This class is not part of the public API and may be removed or changed at any time</p>
@@ -104,4 +105,10 @@ public void reauthenticate(final InternalConnection connection) {
104105
authenticate(connection, connection.getDescription());
105106
}
106107

108+
public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback<Void> callback) {
109+
beginAsync().thenRun((c) -> {
110+
authenticateAsync(connection, connection.getDescription(), c);
111+
}).finish(callback);
112+
}
113+
107114
}

driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public interface InternalConnection extends BufferProvider {
4949
ServerDescription getInitialServerDescription();
5050

5151
/**
52-
* Opens the connection so its ready for use
52+
* Opens the connection so its ready for use. Will perform a handshake.
5353
*/
5454
void open();
5555

driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import com.mongodb.event.CommandListener;
4343
import com.mongodb.internal.ResourceUtil;
4444
import com.mongodb.internal.VisibleForTesting;
45+
import com.mongodb.internal.async.AsyncSupplier;
4546
import com.mongodb.internal.async.SingleResultCallback;
4647
import com.mongodb.internal.diagnostics.logging.Logger;
4748
import com.mongodb.internal.diagnostics.logging.Loggers;
@@ -68,9 +69,12 @@
6869
import java.util.function.Supplier;
6970

7071
import static com.mongodb.assertions.Assertions.assertNotNull;
72+
import static com.mongodb.assertions.Assertions.assertNull;
7173
import static com.mongodb.assertions.Assertions.isTrue;
7274
import static com.mongodb.assertions.Assertions.notNull;
75+
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
7376
import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback;
77+
import static com.mongodb.internal.connection.Authenticator.shouldAuthenticate;
7478
import static com.mongodb.internal.connection.CommandHelper.HELLO;
7579
import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO;
7680
import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO_LOWER;
@@ -238,7 +242,7 @@ public void open() {
238242

239243
@Override
240244
public void openAsync(final SingleResultCallback<Void> callback) {
241-
isTrue("Open already called", stream == null, callback);
245+
assertNull(stream);
242246
try {
243247
stream = streamFactory.create(serverId.getAddress());
244248
stream.openAsync(new AsyncCompletionHandler<Void>() {
@@ -364,17 +368,48 @@ public <T> T sendAndReceive(final CommandMessage message, final Decoder<T> decod
364368
try {
365369
return sendAndReceiveInternal.get();
366370
} catch (MongoCommandException e) {
367-
if (triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description)) {
368-
authenticated.set(false);
369-
authenticator.reauthenticate(this);
370-
authenticated.set(true);
371-
return sendAndReceiveInternal.get();
371+
if (reauthenticationIsTriggered(e)) {
372+
return reauthenticateAndRetry(sendAndReceiveInternal);
372373
}
373374
throw e;
374375
}
375376
}
376377

377-
public static boolean triggersReauthentication(@Nullable final Throwable t) {
378+
@Override
379+
public <T> void sendAndReceiveAsync(final CommandMessage message, final Decoder<T> decoder, final SessionContext sessionContext,
380+
final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback<T> callback) {
381+
382+
AsyncSupplier<T> sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal(
383+
message, decoder, sessionContext, requestContext, operationContext, c);
384+
beginAsync().<T>thenSupply(c -> {
385+
sendAndReceiveAsyncInternal.getAsync(c);
386+
}).onErrorIf(e -> reauthenticationIsTriggered(e), c -> {
387+
reauthenticateAndRetryAsync(sendAndReceiveAsyncInternal, c);
388+
}).finish(callback);
389+
}
390+
391+
private <T> T reauthenticateAndRetry(final Supplier<T> operation) {
392+
authenticated.set(false);
393+
assertNotNull(authenticator).reauthenticate(this);
394+
authenticated.set(true);
395+
return operation.get();
396+
}
397+
398+
private <T> void reauthenticateAndRetryAsync(final AsyncSupplier<T> operation,
399+
final SingleResultCallback<T> callback) {
400+
beginAsync().thenRun(c -> {
401+
authenticated.set(false);
402+
assertNotNull(authenticator).reauthenticateAsync(this, c);
403+
}).<T>thenSupply((c) -> {
404+
authenticated.set(true);
405+
operation.getAsync(c);
406+
}).finish(callback);
407+
}
408+
409+
public boolean reauthenticationIsTriggered(@Nullable final Throwable t) {
410+
if (!shouldAuthenticate(authenticator, this.description)) {
411+
return false;
412+
}
378413
if (t instanceof MongoCommandException) {
379414
MongoCommandException e = (MongoCommandException) t;
380415
return e.getErrorCode() == 391;
@@ -501,11 +536,8 @@ private <T> T receiveCommandMessageResponse(final Decoder<T> decoder,
501536
}
502537
}
503538

504-
@Override
505-
public <T> void sendAndReceiveAsync(final CommandMessage message, final Decoder<T> decoder, final SessionContext sessionContext,
539+
private <T> void sendAndReceiveAsyncInternal(final CommandMessage message, final Decoder<T> decoder, final SessionContext sessionContext,
506540
final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback<T> callback) {
507-
notNull("stream is open", stream, callback);
508-
509541
if (isClosed()) {
510542
callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()));
511543
return;
@@ -616,7 +648,7 @@ public void sendMessage(final List<ByteBuf> byteBuffers, final int lastRequestId
616648

617649
@Override
618650
public ResponseBuffers receiveMessage(final int responseTo) {
619-
notNull("stream is open", stream);
651+
assertNotNull(stream);
620652
if (isClosed()) {
621653
throw new MongoSocketClosedException("Cannot read from a closed stream", getServerAddress());
622654
}
@@ -634,8 +666,9 @@ private ResponseBuffers receiveMessageWithAdditionalTimeout(final int additional
634666
}
635667

636668
@Override
637-
public void sendMessageAsync(final List<ByteBuf> byteBuffers, final int lastRequestId, final SingleResultCallback<Void> callback) {
638-
notNull("stream is open", stream, callback);
669+
public void sendMessageAsync(final List<ByteBuf> byteBuffers, final int lastRequestId,
670+
final SingleResultCallback<Void> callback) {
671+
assertNotNull(stream);
639672

640673
if (isClosed()) {
641674
callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()));
@@ -667,7 +700,7 @@ public void failed(final Throwable t) {
667700

668701
@Override
669702
public void receiveMessageAsync(final int responseTo, final SingleResultCallback<ResponseBuffers> callback) {
670-
isTrue("stream is open", stream != null, callback);
703+
assertNotNull(stream);
671704

672705
if (isClosed()) {
673706
callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()));

driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import java.util.concurrent.locks.StampedLock;
2626

2727
import static com.mongodb.internal.Locks.withInterruptibleLock;
28-
import static com.mongodb.internal.Locks.withLock;
2928
import static com.mongodb.internal.connection.OidcAuthenticator.OidcCacheEntry;
3029

3130
/**

0 commit comments

Comments
 (0)