Skip to content

Replace synchronized with ReentrantLock #984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions driver-core/src/main/com/mongodb/KerberosSubjectProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import javax.security.auth.kerberos.KerberosTicket;
import javax.security.auth.login.LoginContext;
import javax.security.auth.login.LoginException;
import java.util.concurrent.locks.ReentrantLock;

import static com.mongodb.assertions.Assertions.notNull;
import static com.mongodb.internal.Locks.checkedWithLock;
import static java.lang.String.format;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.MINUTES;
Expand All @@ -54,6 +56,7 @@ public class KerberosSubjectProvider implements SubjectProvider {
private static final Logger LOGGER = Loggers.getLogger("authenticator");
private static final String TGT_PREFIX = "krbtgt/";

private final ReentrantLock lock = new ReentrantLock();
private String loginContextName;
private String fallbackLoginContextName;
private Subject subject;
Expand Down Expand Up @@ -87,11 +90,13 @@ private KerberosSubjectProvider(final String loginContextName, @Nullable final S
* @throws LoginException any exception resulting from a call to {@link LoginContext#login()}
*/
@NonNull
public synchronized Subject getSubject() throws LoginException {
if (subject == null || needNewSubject(subject)) {
subject = createNewSubject();
}
return subject;
public Subject getSubject() throws LoginException {
return checkedWithLock(lock, () -> {
if (subject == null || needNewSubject(subject)) {
subject = createNewSubject();
}
return subject;
});
}

private Subject createNewSubject() throws LoginException {
Expand Down
32 changes: 32 additions & 0 deletions driver-core/src/main/com/mongodb/internal/CheckedSupplier.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright 2008-present MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mongodb.internal;

/**
* This class is not part of the public API and may be removed or changed at any time.
*/
@FunctionalInterface
public interface CheckedSupplier<T, E extends Exception> {

/**
* Gets a result.
*
* @return a result
* @throws E the checked exception to throw
*/
T get() throws E;
}
55 changes: 55 additions & 0 deletions driver-core/src/main/com/mongodb/internal/Locks.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2008-present MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mongodb.internal;

import com.mongodb.MongoInterruptedException;

import java.util.concurrent.locks.Lock;
import java.util.function.Supplier;

/**
* This class is not part of the public API and may be removed or changed at any time.
*/
public final class Locks {
public static void withLock(final Lock lock, final Runnable action) {
withLock(lock, () -> {
action.run();
return null;
});
}

public static <V> V withLock(final Lock lock, final Supplier<V> supplier) {
return checkedWithLock(lock, supplier::get);
}

public static <V, E extends Exception> V checkedWithLock(final Lock lock, final CheckedSupplier<V, E> supplier) throws E {
try {
lock.lockInterruptibly();
try {
return supplier.get();
} finally {
lock.unlock();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new MongoInterruptedException("Interrupted waiting for lock", e);
}
}

private Locks() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.mongodb.event.ClusterDescriptionChangedEvent;
import com.mongodb.event.ClusterListener;
import com.mongodb.event.ClusterOpeningEvent;
import com.mongodb.internal.Locks;
import com.mongodb.internal.VisibleForTesting;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.internal.selector.LatencyMinimizingServerSelector;
Expand All @@ -48,6 +49,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;

import static com.mongodb.assertions.Assertions.isTrue;
Expand All @@ -68,6 +70,7 @@ abstract class BaseCluster implements Cluster {

private static final Logger LOGGER = Loggers.getLogger("cluster");

private final ReentrantLock lock = new ReentrantLock();
private final AtomicReference<CountDownLatch> phase = new AtomicReference<CountDownLatch>(new CountDownLatch(1));
private final ClusterableServerFactory serverFactory;
private final ClusterId clusterId;
Expand Down Expand Up @@ -268,8 +271,8 @@ public ClusterDescription getCurrentDescription() {
}

@Override
public synchronized void withLock(final Runnable action) {
action.run();
public void withLock(final Runnable action) {
Locks.withLock(lock, action);
}

private void updatePhase() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,36 @@
import org.bson.BsonDocument;
import org.bson.BsonTimestamp;

import java.util.concurrent.locks.ReentrantLock;

import static com.mongodb.internal.Locks.withLock;

public class ClusterClock {
private static final String CLUSTER_TIME_KEY = "clusterTime";
private final ReentrantLock lock = new ReentrantLock();
private BsonDocument clusterTime;

public synchronized BsonDocument getCurrent() {
return clusterTime;
public BsonDocument getCurrent() {
return withLock(lock, () -> clusterTime);
}

public synchronized BsonTimestamp getClusterTime() {
return clusterTime != null ? clusterTime.getTimestamp(CLUSTER_TIME_KEY) : null;
public BsonTimestamp getClusterTime() {
return withLock(lock, () -> clusterTime != null ? clusterTime.getTimestamp(CLUSTER_TIME_KEY) : null);
}

public synchronized void advance(final BsonDocument other) {
this.clusterTime = greaterOf(other);
public void advance(final BsonDocument other) {
withLock(lock, () -> this.clusterTime = greaterOf(other));
}

public synchronized BsonDocument greaterOf(final BsonDocument other) {
if (other == null) {
return clusterTime;
} else if (clusterTime == null) {
return other;
} else {
return other.getTimestamp(CLUSTER_TIME_KEY).compareTo(clusterTime.getTimestamp(CLUSTER_TIME_KEY)) > 0 ? other : clusterTime;
}
public BsonDocument greaterOf(final BsonDocument other) {
return withLock(lock, () -> {
if (other == null) {
return clusterTime;
} else if (clusterTime == null) {
return other;
} else {
return other.getTimestamp(CLUSTER_TIME_KEY).compareTo(clusterTime.getTimestamp(CLUSTER_TIME_KEY)) > 0 ? other : clusterTime;
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import com.mongodb.AuthenticationMechanism;
import com.mongodb.MongoCredential;

import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import static com.mongodb.internal.Locks.withLock;

public class MongoCredentialWithCache {
private final MongoCredential credential;
private final Cache cache;
Expand Down Expand Up @@ -53,21 +58,29 @@ public void putInCache(final Object key, final Object value) {
cache.set(key, value);
}

public Lock getLock() {
return cache.lock;
}

static class Cache {
private final ReentrantLock lock = new ReentrantLock();
private Object cacheKey;
private Object cacheValue;

synchronized Object get(final Object key) {
if (cacheKey != null && cacheKey.equals(key)) {
return cacheValue;
}
return null;
Object get(final Object key) {
return withLock(lock, () -> {
if (cacheKey != null && cacheKey.equals(key)) {
return cacheValue;
}
return null;
});
}

synchronized void set(final Object key, final Object value) {
cacheKey = key;
cacheValue = value;
void set(final Object key, final Object value) {
withLock(lock, () -> {
cacheKey = key;
cacheValue = value;
});
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@

import static com.mongodb.MongoCredential.JAVA_SUBJECT_KEY;
import static com.mongodb.MongoCredential.JAVA_SUBJECT_PROVIDER_KEY;
import static com.mongodb.internal.Locks.withLock;
import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback;
import static com.mongodb.internal.connection.CommandHelper.executeCommand;
import static com.mongodb.internal.connection.CommandHelper.executeCommandAsync;

abstract class SaslAuthenticator extends Authenticator implements SpeculativeAuthenticator {
public static final Logger LOGGER = Loggers.getLogger("authenticator");
private static final String SUBJECT_PROVIDER_CACHE_KEY = "SUBJECT_PROVIDER";

SaslAuthenticator(final MongoCredentialWithCache credential, final ClusterConnectionMode clusterConnectionMode,
final @Nullable ServerApi serverApi) {
super(credential, clusterConnectionMode, serverApi);
Expand Down Expand Up @@ -205,7 +205,7 @@ protected Subject getSubject() {

@NonNull
private SubjectProvider getSubjectProvider() {
synchronized (getMongoCredentialWithCache()) {
return withLock(getMongoCredentialWithCache().getLock(), () -> {
SubjectProvider subjectProvider =
getMongoCredentialWithCache().getFromCache(SUBJECT_PROVIDER_CACHE_KEY, SubjectProvider.class);
if (subjectProvider == null) {
Expand All @@ -216,7 +216,7 @@ private SubjectProvider getSubjectProvider() {
getMongoCredentialWithCache().putInCache(SUBJECT_PROVIDER_CACHE_KEY, subjectProvider);
}
return subjectProvider;
}
});
}

@NonNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@
package com.mongodb.client.gridfs;

import com.mongodb.MongoGridFSException;
import com.mongodb.client.ClientSession;
import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.gridfs.model.GridFSFile;
import com.mongodb.lang.Nullable;
import com.mongodb.client.ClientSession;
import org.bson.BsonValue;
import org.bson.Document;
import org.bson.types.Binary;

import java.util.concurrent.locks.ReentrantLock;

import static com.mongodb.assertions.Assertions.isTrueArgument;
import static com.mongodb.assertions.Assertions.notNull;
import static com.mongodb.internal.Locks.withLock;
import static java.lang.String.format;

class GridFSDownloadStreamImpl extends GridFSDownloadStream {
Expand All @@ -47,8 +50,8 @@ class GridFSDownloadStreamImpl extends GridFSDownloadStream {
private byte[] buffer = null;
private long markPosition;

private final Object closeLock = new Object();
private final Object cursorLock = new Object();
private final ReentrantLock closeLock = new ReentrantLock();
private final ReentrantLock cursorLock = new ReentrantLock();
private boolean closed = false;

GridFSDownloadStreamImpl(@Nullable final ClientSession clientSession, final GridFSFile fileInfo,
Expand Down Expand Up @@ -156,12 +159,12 @@ public void mark() {
}

@Override
public synchronized void mark(final int readlimit) {
public void mark(final int readlimit) {
markPosition = currentPosition;
}

@Override
public synchronized void reset() {
public void reset() {
checkClosed();
if (currentPosition == markPosition) {
return;
Expand All @@ -184,29 +187,29 @@ public boolean markSupported() {

@Override
public void close() {
synchronized (closeLock) {
withLock(closeLock, () -> {
if (!closed) {
closed = true;
}
discardCursor();
}
});
}

private void checkClosed() {
synchronized (closeLock) {
withLock(closeLock, () -> {
if (closed) {
throw new MongoGridFSException("The InputStream has been closed");
}
}
});
}

private void discardCursor() {
synchronized (cursorLock) {
withLock(cursorLock, () -> {
if (cursor != null) {
cursor.close();
cursor = null;
}
}
});
}

@Nullable
Expand Down
Loading