diff --git a/driver-core/src/main/com/mongodb/internal/ExpirableValue.java b/driver-core/src/main/com/mongodb/internal/ExpirableValue.java new file mode 100644 index 00000000000..74cf13e4cf1 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/ExpirableValue.java @@ -0,0 +1,70 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal; + +import com.mongodb.annotations.ThreadSafe; + +import java.time.Duration; +import java.util.Optional; + +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE; + +/** + * A value associated with a lifetime. + * + *

Instances are shallowly immutable.

+ *

This class is not part of the public API and may be removed or changed at any time

+ */ +@ThreadSafe +public final class ExpirableValue { + private final T value; + private final long deadline; + + public static ExpirableValue expired() { + return new ExpirableValue<>(null, Duration.ofSeconds(-1), System.nanoTime()); + } + + public static ExpirableValue expirable(final T value, final Duration lifetime) { + return expirable(value, lifetime, System.nanoTime()); + } + + public static ExpirableValue expirable(final T value, final Duration lifetime, final long startNanoTime) { + return new ExpirableValue<>(assertNotNull(value), assertNotNull(lifetime), startNanoTime); + } + + private ExpirableValue(final T value, final Duration lifetime, final long currentNanoTime) { + this.value = value; + deadline = currentNanoTime + lifetime.toNanos(); + } + + /** + * Returns {@link Optional#empty()} if the value is expired. Otherwise, returns an {@link Optional} describing the value. + */ + public Optional getValue() { + return getValue(System.nanoTime()); + } + + @VisibleForTesting(otherwise = PRIVATE) + Optional getValue(final long currentNanoTime) { + if (currentNanoTime - deadline > 0) { + return Optional.empty(); + } else { + return Optional.of(value); + } + } +} diff --git a/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java b/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java index 17522e867b0..fa7e55fe034 100644 --- a/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java +++ b/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java @@ -17,11 +17,15 @@ package com.mongodb.internal.authentication; import com.mongodb.MongoClientException; +import com.mongodb.internal.ExpirableValue; import org.bson.BsonDocument; +import org.bson.BsonString; import org.bson.json.JsonParseException; +import java.time.Duration; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import static com.mongodb.internal.authentication.HttpHelper.getHttpContents; @@ -31,25 +35,46 @@ *

This class should not be considered a part of the public API.

*/ public final class AzureCredentialHelper { - public static BsonDocument obtainFromEnvironment() { - String endpoint = "http://" + "169.254.169.254:80" - + "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://vault.azure.net"; - - Map headers = new HashMap<>(); - headers.put("Metadata", "true"); - headers.put("Accept", "application/json"); - - String response = getHttpContents("GET", endpoint, headers); - try { - BsonDocument responseDocument = BsonDocument.parse(response); - if (responseDocument.containsKey("access_token")) { - return new BsonDocument("accessToken", responseDocument.get("access_token")); - } else { - throw new MongoClientException("The access_token is missing from Azure IMDS metadata response."); + private static final String ACCESS_TOKEN_FIELD = "access_token"; + private static final String EXPIRES_IN_FIELD = "expires_in"; + + private static ExpirableValue cachedAccessToken = ExpirableValue.expired(); + + public static synchronized BsonDocument obtainFromEnvironment() { + String accessToken; + Optional cachedValue = cachedAccessToken.getValue(); + if (cachedValue.isPresent()) { + accessToken = cachedValue.get(); + } else { + String endpoint = "http://" + "169.254.169.254:80" + + "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://vault.azure.net"; + + Map headers = new HashMap<>(); + headers.put("Metadata", "true"); + headers.put("Accept", "application/json"); + + long startNanoTime = System.nanoTime(); + BsonDocument responseDocument; + try { + responseDocument = BsonDocument.parse(getHttpContents("GET", endpoint, headers)); + } catch (JsonParseException e) { + throw new MongoClientException("Exception parsing JSON from Azure IMDS metadata response.", e); + } + + if (!responseDocument.isString(ACCESS_TOKEN_FIELD)) { + throw new MongoClientException(String.format( + "The %s field from Azure IMDS metadata response is missing or is not a string", ACCESS_TOKEN_FIELD)); + } + if (!responseDocument.isString(EXPIRES_IN_FIELD)) { + throw new MongoClientException(String.format( + "The %s field from Azure IMDS metadata response is missing or is not a string", EXPIRES_IN_FIELD)); } - } catch (JsonParseException e) { - throw new MongoClientException("Exception parsing JSON from Azure IMDS metadata response.", e); - } + accessToken = responseDocument.getString(ACCESS_TOKEN_FIELD).getValue(); + int expiresInSeconds = Integer.parseInt(responseDocument.getString(EXPIRES_IN_FIELD).getValue()); + cachedAccessToken = ExpirableValue.expirable(accessToken, Duration.ofSeconds(expiresInSeconds).minus(Duration.ofMinutes(1)), + startNanoTime); + } + return new BsonDocument("accessToken", new BsonString(accessToken)); } private AzureCredentialHelper() { diff --git a/driver-core/src/test/unit/com/mongodb/internal/ExpirableValueTest.java b/driver-core/src/test/unit/com/mongodb/internal/ExpirableValueTest.java new file mode 100644 index 00000000000..bce54567bea --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/ExpirableValueTest.java @@ -0,0 +1,65 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal; + +import org.junit.jupiter.api.Test; + +import java.time.Duration; + +import static com.mongodb.internal.ExpirableValue.expired; +import static com.mongodb.internal.ExpirableValue.expirable; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class ExpirableValueTest { + + @Test + void testExpired() { + assertFalse(expired().getValue().isPresent()); + } + + @SuppressWarnings("OptionalGetWithoutIsPresent") + @Test + void testExpirable() { + assertAll( + () -> assertThrows(AssertionError.class, () -> expirable(null, Duration.ofNanos(1))), + () -> assertThrows(AssertionError.class, () -> expirable(1, null)), + () -> assertFalse(expirable(1, Duration.ofNanos(-1)).getValue().isPresent()), + () -> assertFalse(expirable(1, Duration.ZERO).getValue().isPresent()), + () -> assertEquals(1, expirable(1, Duration.ofSeconds(1)).getValue().get()), + () -> { + ExpirableValue expirableValue = expirable(1, Duration.ofNanos(1)); + Thread.sleep(1); + assertFalse(expirableValue.getValue().isPresent()); + }, + () -> { + ExpirableValue expirableValue = expirable(1, Duration.ofMinutes(60), Long.MAX_VALUE); + assertEquals(1, expirableValue.getValue(Long.MAX_VALUE + Duration.ofMinutes(30).toNanos()).get()); + }, + () -> { + ExpirableValue expirableValue = expirable(1, Duration.ofMinutes(60), Long.MAX_VALUE); + assertEquals(1, expirableValue.getValue(Long.MAX_VALUE + Duration.ofMinutes(30).toNanos()).get()); + assertFalse(expirableValue.getValue(Long.MAX_VALUE + Duration.ofMinutes(61).toNanos()).isPresent()); + }, + () -> { + ExpirableValue expirableValue = expirable(1, Duration.ofNanos(10), Long.MAX_VALUE - 20); + assertFalse(expirableValue.getValue(Long.MAX_VALUE - 20 + Duration.ofNanos(30).toNanos()).isPresent()); + }); + } +}