Skip to content

Commit b5b395c

Browse files
authored
Cache Azure credential obtained from environment (#1038)
JAVA-4706
1 parent 1ef1b5e commit b5b395c

File tree

3 files changed

+178
-18
lines changed

3 files changed

+178
-18
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.internal;
18+
19+
import com.mongodb.annotations.ThreadSafe;
20+
21+
import java.time.Duration;
22+
import java.util.Optional;
23+
24+
import static com.mongodb.assertions.Assertions.assertNotNull;
25+
import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE;
26+
27+
/**
28+
* A value associated with a lifetime.
29+
*
30+
* <p>Instances are shallowly immutable.</p>
31+
* <p>This class is not part of the public API and may be removed or changed at any time</p>
32+
*/
33+
@ThreadSafe
34+
public final class ExpirableValue<T> {
35+
private final T value;
36+
private final long deadline;
37+
38+
public static <T> ExpirableValue<T> expired() {
39+
return new ExpirableValue<>(null, Duration.ofSeconds(-1), System.nanoTime());
40+
}
41+
42+
public static <T> ExpirableValue<T> expirable(final T value, final Duration lifetime) {
43+
return expirable(value, lifetime, System.nanoTime());
44+
}
45+
46+
public static <T> ExpirableValue<T> expirable(final T value, final Duration lifetime, final long startNanoTime) {
47+
return new ExpirableValue<>(assertNotNull(value), assertNotNull(lifetime), startNanoTime);
48+
}
49+
50+
private ExpirableValue(final T value, final Duration lifetime, final long currentNanoTime) {
51+
this.value = value;
52+
deadline = currentNanoTime + lifetime.toNanos();
53+
}
54+
55+
/**
56+
* Returns {@link Optional#empty()} if the value is expired. Otherwise, returns an {@link Optional} describing the value.
57+
*/
58+
public Optional<T> getValue() {
59+
return getValue(System.nanoTime());
60+
}
61+
62+
@VisibleForTesting(otherwise = PRIVATE)
63+
Optional<T> getValue(final long currentNanoTime) {
64+
if (currentNanoTime - deadline > 0) {
65+
return Optional.empty();
66+
} else {
67+
return Optional.of(value);
68+
}
69+
}
70+
}

driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717
package com.mongodb.internal.authentication;
1818

1919
import com.mongodb.MongoClientException;
20+
import com.mongodb.internal.ExpirableValue;
2021
import org.bson.BsonDocument;
22+
import org.bson.BsonString;
2123
import org.bson.json.JsonParseException;
2224

25+
import java.time.Duration;
2326
import java.util.HashMap;
2427
import java.util.Map;
28+
import java.util.Optional;
2529

2630
import static com.mongodb.internal.authentication.HttpHelper.getHttpContents;
2731

@@ -31,25 +35,46 @@
3135
* <p>This class should not be considered a part of the public API.</p>
3236
*/
3337
public final class AzureCredentialHelper {
34-
public static BsonDocument obtainFromEnvironment() {
35-
String endpoint = "http://" + "169.254.169.254:80"
36-
+ "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://vault.azure.net";
37-
38-
Map<String, String> headers = new HashMap<>();
39-
headers.put("Metadata", "true");
40-
headers.put("Accept", "application/json");
41-
42-
String response = getHttpContents("GET", endpoint, headers);
43-
try {
44-
BsonDocument responseDocument = BsonDocument.parse(response);
45-
if (responseDocument.containsKey("access_token")) {
46-
return new BsonDocument("accessToken", responseDocument.get("access_token"));
47-
} else {
48-
throw new MongoClientException("The access_token is missing from Azure IMDS metadata response.");
38+
private static final String ACCESS_TOKEN_FIELD = "access_token";
39+
private static final String EXPIRES_IN_FIELD = "expires_in";
40+
41+
private static ExpirableValue<String> cachedAccessToken = ExpirableValue.expired();
42+
43+
public static synchronized BsonDocument obtainFromEnvironment() {
44+
String accessToken;
45+
Optional<String> cachedValue = cachedAccessToken.getValue();
46+
if (cachedValue.isPresent()) {
47+
accessToken = cachedValue.get();
48+
} else {
49+
String endpoint = "http://" + "169.254.169.254:80"
50+
+ "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://vault.azure.net";
51+
52+
Map<String, String> headers = new HashMap<>();
53+
headers.put("Metadata", "true");
54+
headers.put("Accept", "application/json");
55+
56+
long startNanoTime = System.nanoTime();
57+
BsonDocument responseDocument;
58+
try {
59+
responseDocument = BsonDocument.parse(getHttpContents("GET", endpoint, headers));
60+
} catch (JsonParseException e) {
61+
throw new MongoClientException("Exception parsing JSON from Azure IMDS metadata response.", e);
62+
}
63+
64+
if (!responseDocument.isString(ACCESS_TOKEN_FIELD)) {
65+
throw new MongoClientException(String.format(
66+
"The %s field from Azure IMDS metadata response is missing or is not a string", ACCESS_TOKEN_FIELD));
67+
}
68+
if (!responseDocument.isString(EXPIRES_IN_FIELD)) {
69+
throw new MongoClientException(String.format(
70+
"The %s field from Azure IMDS metadata response is missing or is not a string", EXPIRES_IN_FIELD));
4971
}
50-
} catch (JsonParseException e) {
51-
throw new MongoClientException("Exception parsing JSON from Azure IMDS metadata response.", e);
52-
}
72+
accessToken = responseDocument.getString(ACCESS_TOKEN_FIELD).getValue();
73+
int expiresInSeconds = Integer.parseInt(responseDocument.getString(EXPIRES_IN_FIELD).getValue());
74+
cachedAccessToken = ExpirableValue.expirable(accessToken, Duration.ofSeconds(expiresInSeconds).minus(Duration.ofMinutes(1)),
75+
startNanoTime);
76+
}
77+
return new BsonDocument("accessToken", new BsonString(accessToken));
5378
}
5479

5580
private AzureCredentialHelper() {
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.internal;
18+
19+
import org.junit.jupiter.api.Test;
20+
21+
import java.time.Duration;
22+
23+
import static com.mongodb.internal.ExpirableValue.expired;
24+
import static com.mongodb.internal.ExpirableValue.expirable;
25+
import static org.junit.jupiter.api.Assertions.assertAll;
26+
import static org.junit.jupiter.api.Assertions.assertEquals;
27+
import static org.junit.jupiter.api.Assertions.assertFalse;
28+
import static org.junit.jupiter.api.Assertions.assertThrows;
29+
30+
class ExpirableValueTest {
31+
32+
@Test
33+
void testExpired() {
34+
assertFalse(expired().getValue().isPresent());
35+
}
36+
37+
@SuppressWarnings("OptionalGetWithoutIsPresent")
38+
@Test
39+
void testExpirable() {
40+
assertAll(
41+
() -> assertThrows(AssertionError.class, () -> expirable(null, Duration.ofNanos(1))),
42+
() -> assertThrows(AssertionError.class, () -> expirable(1, null)),
43+
() -> assertFalse(expirable(1, Duration.ofNanos(-1)).getValue().isPresent()),
44+
() -> assertFalse(expirable(1, Duration.ZERO).getValue().isPresent()),
45+
() -> assertEquals(1, expirable(1, Duration.ofSeconds(1)).getValue().get()),
46+
() -> {
47+
ExpirableValue<Integer> expirableValue = expirable(1, Duration.ofNanos(1));
48+
Thread.sleep(1);
49+
assertFalse(expirableValue.getValue().isPresent());
50+
},
51+
() -> {
52+
ExpirableValue<Integer> expirableValue = expirable(1, Duration.ofMinutes(60), Long.MAX_VALUE);
53+
assertEquals(1, expirableValue.getValue(Long.MAX_VALUE + Duration.ofMinutes(30).toNanos()).get());
54+
},
55+
() -> {
56+
ExpirableValue<Integer> expirableValue = expirable(1, Duration.ofMinutes(60), Long.MAX_VALUE);
57+
assertEquals(1, expirableValue.getValue(Long.MAX_VALUE + Duration.ofMinutes(30).toNanos()).get());
58+
assertFalse(expirableValue.getValue(Long.MAX_VALUE + Duration.ofMinutes(61).toNanos()).isPresent());
59+
},
60+
() -> {
61+
ExpirableValue<Integer> expirableValue = expirable(1, Duration.ofNanos(10), Long.MAX_VALUE - 20);
62+
assertFalse(expirableValue.getValue(Long.MAX_VALUE - 20 + Duration.ofNanos(30).toNanos()).isPresent());
63+
});
64+
}
65+
}

0 commit comments

Comments
 (0)