diff --git a/driver-core/src/main/com/mongodb/internal/connection/BaseCluster.java b/driver-core/src/main/com/mongodb/internal/connection/BaseCluster.java index 317b83b8b8f..b7747d0b3dc 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/BaseCluster.java +++ b/driver-core/src/main/com/mongodb/internal/connection/BaseCluster.java @@ -165,56 +165,6 @@ public void selectServerAsync(final ServerSelector serverSelector, final Operati } } - @Override - public ClusterDescription getDescription() { - isTrue("open", !isClosed()); - - try { - CountDownLatch currentPhase = phase.get(); - ClusterDescription curDescription = description; - - boolean selectionFailureLogged = false; - - long startTimeNanos = System.nanoTime(); - long curTimeNanos = startTimeNanos; - long maxWaitTimeNanos = getMaxWaitTimeNanos(); - - while (curDescription.getType() == ClusterType.UNKNOWN) { - - if (curTimeNanos - startTimeNanos > maxWaitTimeNanos) { - throw new MongoTimeoutException(format("Timed out after %d ms while waiting to connect. Client view of cluster state " - + "is %s", - settings.getServerSelectionTimeout(MILLISECONDS), - curDescription.getShortDescription())); - } - - if (!selectionFailureLogged) { - if (LOGGER.isInfoEnabled()) { - if (settings.getServerSelectionTimeout(MILLISECONDS) < 0) { - LOGGER.info("Cluster description not yet available. Waiting indefinitely."); - } else { - LOGGER.info(format("Cluster description not yet available. Waiting for %d ms before timing out", - settings.getServerSelectionTimeout(MILLISECONDS))); - } - } - selectionFailureLogged = true; - } - - connect(); - - currentPhase.await(Math.min(maxWaitTimeNanos - (curTimeNanos - startTimeNanos), getMinWaitTimeNanos()), NANOSECONDS); - - curTimeNanos = System.nanoTime(); - - currentPhase = phase.get(); - curDescription = description; - } - return curDescription; - } catch (InterruptedException e) { - throw interruptAndCreateMongoInterruptedException("Interrupted while waiting to connect", e); - } - } - public ClusterId getClusterId() { return clusterId; } diff --git a/driver-core/src/main/com/mongodb/internal/connection/Cluster.java b/driver-core/src/main/com/mongodb/internal/connection/Cluster.java index eb409c7851d..a3a649b10a6 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Cluster.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Cluster.java @@ -40,15 +40,6 @@ public interface Cluster extends Closeable { ClusterSettings getSettings(); - /** - * Get the description of this cluster. This method will not return normally until the cluster type is known. - * - * @return a ClusterDescription representing the current state of the cluster - * @throws com.mongodb.MongoTimeoutException if the timeout has been reached before the cluster type is known - * @throws com.mongodb.MongoInterruptedException if interrupted when getting the cluster description - */ - ClusterDescription getDescription(); - ClusterId getClusterId(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedCluster.java b/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedCluster.java index bf910995106..c4bbf695b59 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedCluster.java +++ b/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedCluster.java @@ -174,13 +174,6 @@ public ClusterSettings getSettings() { return settings; } - @Override - public ClusterDescription getDescription() { - isTrue("open", !isClosed()); - waitForSrv(); - return description; - } - @Override public ClusterId getClusterId() { return clusterId; diff --git a/driver-core/src/test/functional/com/mongodb/ClusterFixture.java b/driver-core/src/test/functional/com/mongodb/ClusterFixture.java index af96a431e5c..ba7acd78704 100644 --- a/driver-core/src/test/functional/com/mongodb/ClusterFixture.java +++ b/driver-core/src/test/functional/com/mongodb/ClusterFixture.java @@ -19,6 +19,7 @@ import com.mongodb.async.FutureResultCallback; import com.mongodb.connection.AsynchronousSocketChannelStreamFactory; import com.mongodb.connection.ClusterConnectionMode; +import com.mongodb.connection.ClusterDescription; import com.mongodb.connection.ClusterSettings; import com.mongodb.connection.ClusterType; import com.mongodb.connection.ConnectionPoolSettings; @@ -85,6 +86,7 @@ import static com.mongodb.connection.ClusterType.REPLICA_SET; import static com.mongodb.connection.ClusterType.SHARDED; import static com.mongodb.connection.ClusterType.STANDALONE; +import static com.mongodb.connection.ClusterType.UNKNOWN; import static com.mongodb.internal.connection.ClusterDescriptionHelper.getPrimaries; import static com.mongodb.internal.connection.ClusterDescriptionHelper.getSecondaries; import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException; @@ -140,7 +142,20 @@ public static String getDefaultDatabaseName() { } public static boolean clusterIsType(final ClusterType clusterType) { - return getCluster().getDescription().getType() == clusterType; + return getClusterDescription(getCluster()).getType() == clusterType; + } + + public static ClusterDescription getClusterDescription(final Cluster cluster) { + try { + ClusterDescription clusterDescription = cluster.getCurrentDescription(); + while (clusterDescription.getType() == UNKNOWN) { + Thread.sleep(10); + clusterDescription = cluster.getCurrentDescription(); + } + return clusterDescription; + } catch (InterruptedException e) { + throw interruptAndCreateMongoInterruptedException("Interrupted", e); + } } public static ServerVersion getServerVersion() { @@ -449,27 +464,27 @@ public static SslSettings getSslSettings(final ConnectionString connectionString } public static ServerAddress getPrimary() { - List serverDescriptions = getPrimaries(getCluster().getDescription()); + List serverDescriptions = getPrimaries(getClusterDescription(getCluster())); while (serverDescriptions.isEmpty()) { try { sleep(100); } catch (InterruptedException e) { throw new RuntimeException(e); } - serverDescriptions = getPrimaries(getCluster().getDescription()); + serverDescriptions = getPrimaries(getClusterDescription(getCluster())); } return serverDescriptions.get(0).getAddress(); } public static ServerAddress getSecondary() { - List serverDescriptions = getSecondaries(getCluster().getDescription()); + List serverDescriptions = getSecondaries(getClusterDescription(getCluster())); while (serverDescriptions.isEmpty()) { try { sleep(100); } catch (InterruptedException e) { throw new RuntimeException(e); } - serverDescriptions = getSecondaries(getCluster().getDescription()); + serverDescriptions = getSecondaries(getClusterDescription(getCluster())); } return serverDescriptions.get(0).getAddress(); } @@ -499,20 +514,19 @@ public static BsonDocument getServerParameters() { } public static boolean isDiscoverableReplicaSet() { - return getCluster().getDescription().getType() == REPLICA_SET - && getCluster().getDescription().getConnectionMode() == MULTIPLE; + return clusterIsType(REPLICA_SET) && getClusterConnectionMode() == MULTIPLE; } public static boolean isSharded() { - return getCluster().getDescription().getType() == SHARDED; + return clusterIsType(SHARDED); } public static boolean isStandalone() { - return getCluster().getDescription().getType() == STANDALONE; + return clusterIsType(STANDALONE); } public static boolean isLoadBalanced() { - return getCluster().getSettings().getMode() == LOAD_BALANCED; + return getClusterConnectionMode() == LOAD_BALANCED; } public static boolean isAuthenticated() { diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/SingleServerClusterTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/SingleServerClusterTest.java index 1f8ad92eaf4..af98ef2fc28 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/SingleServerClusterTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/SingleServerClusterTest.java @@ -79,23 +79,14 @@ public void tearDown() { cluster.close(); } - @Test - public void shouldGetDescription() { - // given - setUpCluster(getPrimary()); - - // expect - assertNotNull(cluster.getDescription()); - } - @Test public void descriptionShouldIncludeSettings() { // given setUpCluster(getPrimary()); // expect - assertNotNull(cluster.getDescription().getClusterSettings()); - assertNotNull(cluster.getDescription().getServerSettings()); + assertNotNull(cluster.getCurrentDescription().getClusterSettings()); + assertNotNull(cluster.getCurrentDescription().getServerSettings()); } @Test diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/BaseClusterSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/BaseClusterSpecification.groovy index 48c21f0e2f1..39c52b23821 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/BaseClusterSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/BaseClusterSpecification.groovy @@ -80,12 +80,6 @@ class BaseClusterSpecification extends Specification { cluster.getCurrentDescription() == new ClusterDescription(clusterSettings.getMode(), ClusterType.UNKNOWN, [], clusterSettings, factory.getSettings()) - when: 'the description is accessed before initialization' - cluster.getDescription() - - then: 'a MongoTimeoutException is thrown' - thrown(MongoTimeoutException) - when: 'a server is selected before initialization' cluster.selectServer({ def clusterDescription -> [] }, new OperationContext()) @@ -193,21 +187,10 @@ class BaseClusterSpecification extends Specification { .exception(new MongoInternalException('oops')) .build()) - cluster.getDescription() - - then: - def e = thrown(MongoTimeoutException) - e.getMessage().startsWith("Timed out after ${serverSelectionTimeoutMS} ms while waiting to connect. " + - 'Client view of cluster state is {type=UNKNOWN') - e.getMessage().contains('{address=localhost:27017, type=UNKNOWN, state=CONNECTING, ' + - 'exception={com.mongodb.MongoInternalException: oops}}') - e.getMessage().contains('{address=localhost:27018, type=UNKNOWN, state=CONNECTING}') - - when: cluster.selectServer(new WritableServerSelector(), new OperationContext()) then: - e = thrown(MongoTimeoutException) + def e = thrown(MongoTimeoutException) e.getMessage().startsWith("Timed out after ${serverSelectionTimeoutMS} ms while waiting for a server " + 'that matches WritableServerSelector. Client view of cluster state is {type=UNKNOWN') e.getMessage().contains('{address=localhost:27017, type=UNKNOWN, state=CONNECTING, ' + @@ -272,37 +255,6 @@ class BaseClusterSpecification extends Specification { cluster?.close() } - @Slow - def 'should wait indefinitely for a cluster description until interrupted'() { - given: - def cluster = new MultiServerCluster(new ClusterId(), - builder().mode(MULTIPLE) - .hosts([firstServer, secondServer, thirdServer]) - .serverSelectionTimeout(-1, SECONDS) - .build(), - factory) - - when: - def latch = new CountDownLatch(1) - def thread = new Thread({ - try { - cluster.getDescription() - } catch (MongoInterruptedException e) { - latch.countDown() - } - }) - thread.start() - sleep(1000) - thread.interrupt() - def interrupted = latch.await(ClusterFixture.TIMEOUT, SECONDS) - - then: - interrupted - - cleanup: - cluster?.close() - } - def 'should select server asynchronously when server is already available'() { given: def cluster = new MultiServerCluster(new ClusterId(), diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/DnsMultiServerClusterSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/DnsMultiServerClusterSpecification.groovy index 75a8572a999..2c381165acd 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/DnsMultiServerClusterSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/DnsMultiServerClusterSpecification.groovy @@ -98,7 +98,7 @@ class DnsMultiServerClusterSpecification extends Specification { factory.sendNotification(secondServer, SHARD_ROUTER) def firstTestServer = factory.getServer(firstServer) def secondTestServer = factory.getServer(secondServer) - def clusterDescription = cluster.getDescription() + def clusterDescription = cluster.getCurrentDescription() then: 'events are generated, description includes hosts, exception is cleared, and servers are open' 2 * clusterListener.clusterDescriptionChanged(_) @@ -112,7 +112,7 @@ class DnsMultiServerClusterSpecification extends Specification { initializer.initialize([secondServer, thirdServer]) factory.sendNotification(secondServer, SHARD_ROUTER) def thirdTestServer = factory.getServer(thirdServer) - clusterDescription = cluster.getDescription() + clusterDescription = cluster.getCurrentDescription() then: 'events are generated, description is updated, and the removed server is closed' 1 * clusterListener.clusterDescriptionChanged(_) @@ -125,7 +125,7 @@ class DnsMultiServerClusterSpecification extends Specification { when: 'the listener is initialized with another exception' initializer.initialize(exception) - clusterDescription = cluster.getDescription() + clusterDescription = cluster.getCurrentDescription() then: 'the exception is ignored' 0 * clusterListener.clusterDescriptionChanged(_) diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/MultiServerClusterSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/MultiServerClusterSpecification.groovy index 66667bd11da..096053a0b11 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/MultiServerClusterSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/MultiServerClusterSpecification.groovy @@ -17,7 +17,6 @@ package com.mongodb.internal.connection -import com.mongodb.MongoTimeoutException import com.mongodb.ServerAddress import com.mongodb.connection.ClusterDescription import com.mongodb.connection.ClusterId @@ -71,21 +70,8 @@ class MultiServerClusterSpecification extends Specification { sendNotification(firstServer, REPLICA_SET_PRIMARY) expect: - cluster.getDescription().clusterSettings != null - cluster.getDescription().serverSettings != null - } - - def 'should timeout waiting for description if no servers connect'() { - given: - def cluster = new MultiServerCluster(CLUSTER_ID, ClusterSettings.builder().mode(MULTIPLE) - .serverSelectionTimeout(1, MILLISECONDS) - .hosts([firstServer]).build(), factory) - - when: - cluster.getDescription() - - then: - thrown(MongoTimeoutException) + cluster.getCurrentDescription().clusterSettings != null + cluster.getCurrentDescription().serverSettings != null } def 'should correct report description when connected to a primary'() { @@ -97,8 +83,8 @@ class MultiServerClusterSpecification extends Specification { sendNotification(firstServer, REPLICA_SET_PRIMARY) then: - cluster.getDescription().type == REPLICA_SET - cluster.getDescription().connectionMode == MULTIPLE + cluster.getCurrentDescription().type == REPLICA_SET + cluster.getCurrentDescription().connectionMode == MULTIPLE } def 'should not get server when closed'() { @@ -123,7 +109,7 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(firstServer, REPLICA_SET_PRIMARY, [firstServer, secondServer, thirdServer]) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) } def 'should discover all hosts in the cluster when notified by a secondary and there is no primary'() { @@ -135,7 +121,7 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(firstServer, REPLICA_SET_SECONDARY, [firstServer, secondServer, thirdServer]) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) } def 'should discover all passives in the cluster'() { @@ -147,7 +133,7 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(firstServer, REPLICA_SET_PRIMARY, [firstServer], [secondServer, thirdServer]) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) } def 'should remove a secondary server whose reported host name does not match the address connected to'() { @@ -160,7 +146,7 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(seedListAddress, REPLICA_SET_SECONDARY, [firstServer, secondServer], firstServer) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer) } def 'should remove a primary server whose reported host name does not match the address connected to'() { @@ -173,7 +159,7 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(seedListAddress, REPLICA_SET_PRIMARY, [firstServer, secondServer], firstServer) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer) } def 'should remove a server when it no longer appears in hosts reported by the primary'() { @@ -188,7 +174,7 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(firstServer, REPLICA_SET_PRIMARY, [firstServer, secondServer]) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer) factory.getServer(thirdServer).isClosed() } @@ -202,8 +188,8 @@ class MultiServerClusterSpecification extends Specification { sendNotification(secondServer, SHARD_ROUTER) then: - cluster.getDescription().type == REPLICA_SET - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer) + cluster.getCurrentDescription().type == REPLICA_SET + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer) } def 'should ignore an empty list of hosts when type is replica set'() { @@ -216,9 +202,9 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(secondServer, REPLICA_SET_GHOST, []) then: - cluster.getDescription().type == REPLICA_SET - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer) - getByServerAddress(cluster.getDescription(), secondServer).getType() == REPLICA_SET_GHOST + cluster.getCurrentDescription().type == REPLICA_SET + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer) + getByServerAddress(cluster.getCurrentDescription(), secondServer).getType() == REPLICA_SET_GHOST } def 'should ignore a host without a replica set name when type is replica set'() { @@ -231,9 +217,9 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(secondServer, REPLICA_SET_GHOST, [firstServer, secondServer], (String) null) // null replica set name then: - cluster.getDescription().type == REPLICA_SET - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer) - getByServerAddress(cluster.getDescription(), secondServer).getType() == REPLICA_SET_GHOST + cluster.getCurrentDescription().type == REPLICA_SET + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer) + getByServerAddress(cluster.getCurrentDescription(), secondServer).getType() == REPLICA_SET_GHOST } def 'should remove a server of the wrong type when type is sharded'() { @@ -247,8 +233,8 @@ class MultiServerClusterSpecification extends Specification { sendNotification(secondServer, REPLICA_SET_PRIMARY) then: - cluster.getDescription().type == SHARDED - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer) + cluster.getCurrentDescription().type == SHARDED + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer) } def 'should remove a server of wrong type from discovered replica set'() { @@ -261,8 +247,8 @@ class MultiServerClusterSpecification extends Specification { sendNotification(secondServer, STANDALONE) then: - cluster.getDescription().type == REPLICA_SET - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, thirdServer) + cluster.getCurrentDescription().type == REPLICA_SET + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, thirdServer) } def 'should not set cluster type when connected to a standalone when seed list size is greater than one'() { @@ -275,10 +261,9 @@ class MultiServerClusterSpecification extends Specification { when: sendNotification(firstServer, STANDALONE) - cluster.getDescription() then: - thrown(MongoTimeoutException) + cluster.getCurrentDescription().getType() == UNKNOWN } def 'should not set cluster type when connected to a replica set ghost until a valid replica set member connects'() { @@ -291,17 +276,16 @@ class MultiServerClusterSpecification extends Specification { when: sendNotification(firstServer, REPLICA_SET_GHOST) - cluster.getDescription() then: - thrown(MongoTimeoutException) + cluster.getCurrentDescription().getType() == UNKNOWN when: sendNotification(secondServer, REPLICA_SET_PRIMARY) then: - cluster.getDescription().type == REPLICA_SET - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) + cluster.getCurrentDescription().type == REPLICA_SET + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) } def 'should invalidate existing primary when a new primary notifies'() { @@ -315,7 +299,7 @@ class MultiServerClusterSpecification extends Specification { then: factory.getDescription(firstServer).state == CONNECTING - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) } def 'should invalidate new primary if its electionId is less than the previously reported electionId'() { @@ -330,7 +314,7 @@ class MultiServerClusterSpecification extends Specification { factory.getDescription(firstServer).state == CONNECTED factory.getDescription(firstServer).type == REPLICA_SET_PRIMARY factory.getDescription(secondServer).state == CONNECTING - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) } def 'should remove a server when a server in the seed list is not in hosts list, it should be removed'() { @@ -343,7 +327,7 @@ class MultiServerClusterSpecification extends Specification { sendNotification(serverAddressAlias, REPLICA_SET_PRIMARY) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) } def 'should retain a Standalone server given a hosts list of size 1'() { @@ -355,8 +339,8 @@ class MultiServerClusterSpecification extends Specification { sendNotification(firstServer, STANDALONE) then: - cluster.getDescription().type == ClusterType.STANDALONE - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer) + cluster.getCurrentDescription().type == ClusterType.STANDALONE + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer) } def 'should remove any Standalone server given a hosts list of size greater than one'() { @@ -370,8 +354,8 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(secondServer, REPLICA_SET_PRIMARY, [secondServer, thirdServer]) then: - !(factory.getDescription(firstServer) in getAll(cluster.getDescription())) - cluster.getDescription().type == REPLICA_SET + !(factory.getDescription(firstServer) in getAll(cluster.getCurrentDescription())) + cluster.getCurrentDescription().type == REPLICA_SET } def 'should remove a member whose replica set name does not match the required one'() { @@ -383,8 +367,8 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(secondServer, REPLICA_SET_PRIMARY, [firstServer, secondServer, thirdServer], 'test2') then: - cluster.getDescription().type == REPLICA_SET - getAll(cluster.getDescription()) == [] as Set + cluster.getCurrentDescription().type == REPLICA_SET + getAll(cluster.getCurrentDescription()) == [] as Set } def 'should throw from getServer if cluster is closed'() { @@ -411,7 +395,7 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(secondServer, REPLICA_SET_SECONDARY, [secondServer]) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, thirdServer) } def 'should add servers from a secondary host list when there is no primary'() { @@ -424,7 +408,7 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(secondServer, REPLICA_SET_SECONDARY, [secondServer, thirdServer]) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) } def 'should add and removes servers from a primary host list when there is a primary'() { @@ -437,13 +421,13 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(firstServer, REPLICA_SET_PRIMARY, [firstServer, thirdServer]) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, thirdServer) when: factory.sendNotification(thirdServer, REPLICA_SET_PRIMARY, [secondServer, thirdServer]) then: - getAll(cluster.getDescription()) == factory.getDescriptions(secondServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(secondServer, thirdServer) } def 'should ignore a secondary host list when there is a primary'() { @@ -456,7 +440,7 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(secondServer, REPLICA_SET_SECONDARY, [secondServer, thirdServer]) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer) } def 'should ignore a notification from a server that is not ok'() { @@ -469,7 +453,7 @@ class MultiServerClusterSpecification extends Specification { factory.sendNotification(secondServer, REPLICA_SET_SECONDARY, [], false) then: - getAll(cluster.getDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) + getAll(cluster.getCurrentDescription()) == factory.getDescriptions(firstServer, secondServer, thirdServer) } def 'should fire cluster events'() { diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ServerDiscoveryAndMonitoringTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ServerDiscoveryAndMonitoringTest.java index 05ffee40ae1..4af47cb9557 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ServerDiscoveryAndMonitoringTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ServerDiscoveryAndMonitoringTest.java @@ -30,6 +30,7 @@ import java.net.URISyntaxException; import java.util.Collection; +import static com.mongodb.ClusterFixture.getClusterDescription; import static com.mongodb.internal.connection.ClusterDescriptionHelper.getPrimaries; import static com.mongodb.internal.event.EventListenerHelper.NO_OP_CLUSTER_LISTENER; import static com.mongodb.internal.event.EventListenerHelper.NO_OP_SERVER_LISTENER; @@ -140,7 +141,7 @@ private void assertTopologyType(final String topologyType) { case "Single": assertTrue(getCluster().getClass() == SingleServerCluster.class || (getCluster().getClass() == MultiServerCluster.class - && getCluster().getDescription().getType() == ClusterType.STANDALONE)); + && getClusterDescription(getCluster()).getType() == ClusterType.STANDALONE)); assertEquals(getClusterType(topologyType, getCluster().getCurrentDescription().getServerDescriptions()), getCluster().getCurrentDescription().getType()); break; diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/SingleServerClusterSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/SingleServerClusterSpecification.groovy index 4e091651cc8..f47ab6644d8 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/SingleServerClusterSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/SingleServerClusterSpecification.groovy @@ -59,9 +59,9 @@ class SingleServerClusterSpecification extends Specification { sendNotification(firstServer, STANDALONE) then: - cluster.getDescription().type == ClusterType.STANDALONE - cluster.getDescription().connectionMode == SINGLE - ClusterDescriptionHelper.getAll(cluster.getDescription()) == getDescriptions() + cluster.getCurrentDescription().type == ClusterType.STANDALONE + cluster.getCurrentDescription().connectionMode == SINGLE + ClusterDescriptionHelper.getAll(cluster.getCurrentDescription()) == getDescriptions() cleanup: cluster?.close() @@ -109,8 +109,8 @@ class SingleServerClusterSpecification extends Specification { sendNotification(firstServer, ServerType.REPLICA_SET_PRIMARY) then: - cluster.getDescription().type == ClusterType.SHARDED - ClusterDescriptionHelper.getAll(cluster.getDescription()) == [] as Set + cluster.getCurrentDescription().type == ClusterType.SHARDED + ClusterDescriptionHelper.getAll(cluster.getCurrentDescription()) == [] as Set cleanup: cluster?.close() @@ -126,8 +126,8 @@ class SingleServerClusterSpecification extends Specification { sendNotification(firstServer, ServerType.REPLICA_SET_PRIMARY, 'test1') then: - cluster.getDescription().type == REPLICA_SET - ClusterDescriptionHelper.getAll(cluster.getDescription()) == getDescriptions() + cluster.getCurrentDescription().type == REPLICA_SET + ClusterDescriptionHelper.getAll(cluster.getCurrentDescription()) == getDescriptions() cleanup: cluster?.close() diff --git a/driver-legacy/src/test/functional/com/mongodb/Fixture.java b/driver-legacy/src/test/functional/com/mongodb/Fixture.java index 4651d3c3c36..53c92a2b445 100644 --- a/driver-legacy/src/test/functional/com/mongodb/Fixture.java +++ b/driver-legacy/src/test/functional/com/mongodb/Fixture.java @@ -21,6 +21,7 @@ import java.util.List; +import static com.mongodb.ClusterFixture.getClusterDescription; import static com.mongodb.ClusterFixture.getServerApi; import static com.mongodb.internal.connection.ClusterDescriptionHelper.getPrimaries; @@ -97,10 +98,10 @@ public static MongoClientOptions getOptions() { public static ServerAddress getPrimary() throws InterruptedException { getMongoClient(); - List serverDescriptions = getPrimaries(mongoClient.getCluster().getDescription()); + List serverDescriptions = getPrimaries(getClusterDescription(mongoClient.getCluster())); while (serverDescriptions.isEmpty()) { Thread.sleep(100); - serverDescriptions = getPrimaries(mongoClient.getCluster().getDescription()); + serverDescriptions = getPrimaries(getClusterDescription(mongoClient.getCluster())); } return serverDescriptions.get(0).getAddress(); } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java index 1ff3772cde5..46fa37bf8d2 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionBinding.java @@ -23,6 +23,7 @@ import com.mongodb.connection.ClusterType; import com.mongodb.connection.ServerDescription; import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.async.function.AsyncCallbackSupplier; import com.mongodb.internal.binding.AbstractReferenceCounted; import com.mongodb.internal.binding.AsyncClusterAwareReadWriteBinding; import com.mongodb.internal.binding.AsyncConnectionSource; @@ -39,6 +40,7 @@ import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.connection.ClusterType.LOAD_BALANCED; +import static com.mongodb.connection.ClusterType.SHARDED; /** *

This class is not part of the public API and may be removed or changed at any time

@@ -61,31 +63,6 @@ public ReadPreference getReadPreference() { return wrapped.getReadPreference(); } - @Override - public void getReadConnectionSource(final SingleResultCallback callback) { - isConnectionSourcePinningRequired((isConnectionSourcePinningRequired, t) -> { - if (t != null) { - callback.onResult(null, t); - } else if (isConnectionSourcePinningRequired) { - getPinnedConnectionSource(true, callback); - } else { - wrapped.getReadConnectionSource(new WrappingCallback(callback)); - } - }); - } - - public void getWriteConnectionSource(final SingleResultCallback callback) { - isConnectionSourcePinningRequired((isConnectionSourcePinningRequired, t) -> { - if (t != null) { - callback.onResult(null, t); - } else if (isConnectionSourcePinningRequired) { - getPinnedConnectionSource(false, callback); - } else { - wrapped.getWriteConnectionSource(new WrappingCallback(callback)); - } - }); - } - @Override public SessionContext getSessionContext() { return sessionContext; @@ -107,28 +84,47 @@ public OperationContext getOperationContext() { return wrapped.getOperationContext(); } - private void getPinnedConnectionSource(final boolean isRead, final SingleResultCallback callback) { + @Override + public void getReadConnectionSource(final SingleResultCallback callback) { + getConnectionSource(wrapped::getReadConnectionSource, callback); + } + + @Override + public void getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, + final SingleResultCallback callback) { + getConnectionSource(wrappedConnectionSourceCallback -> + wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, wrappedConnectionSourceCallback), + callback); + } + + public void getWriteConnectionSource(final SingleResultCallback callback) { + getConnectionSource(wrapped::getWriteConnectionSource, callback); + } + + private void getConnectionSource(final AsyncCallbackSupplier connectionSourceSupplier, + final SingleResultCallback callback) { WrappingCallback wrappingCallback = new WrappingCallback(callback); - TransactionContext transactionContext = TransactionContext.get(session); - if (transactionContext == null) { - SingleResultCallback connectionSourceCallback = (result, t) -> { + + if (!session.hasActiveTransaction()) { + connectionSourceSupplier.get(wrappingCallback); + return; + } + if (TransactionContext.get(session) == null) { + connectionSourceSupplier.get((source, t) -> { if (t != null) { wrappingCallback.onResult(null, t); } else { - TransactionContext newTransactionContext = new TransactionContext<>( - wrapped.getCluster().getDescription().getType()); - session.setTransactionContext(result.getServerDescription().getAddress(), newTransactionContext); - newTransactionContext.release(); // The session is responsible for retaining a reference to the context - wrappingCallback.onResult(result, null); + ClusterType clusterType = assertNotNull(source).getServerDescription().getClusterType(); + if (clusterType == SHARDED || clusterType == LOAD_BALANCED) { + TransactionContext transactionContext = new TransactionContext<>(clusterType); + session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext); + transactionContext.release(); // The session is responsible for retaining a reference to the context + } + wrappingCallback.onResult(source, null); } - }; - if (isRead) { - wrapped.getReadConnectionSource(connectionSourceCallback); - } else { - wrapped.getWriteConnectionSource(connectionSourceCallback); - } + }); } else { - wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress()), new WrappingCallback(callback)); + wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress()), wrappingCallback); } } @@ -138,12 +134,6 @@ public AsyncReadWriteBinding retain() { return this; } - @Override - public void getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference, - final SingleResultCallback callback) { - wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference, callback); - } - @Override public int release() { int count = super.release(); @@ -156,19 +146,6 @@ public int release() { return count; } - private void isConnectionSourcePinningRequired(final SingleResultCallback callback) { - try { - callback.onResult(isConnectionSourcePinningRequired(), null); - } catch (Exception e) { - callback.onResult(null, e); - } - } - - private boolean isConnectionSourcePinningRequired() { - ClusterType clusterType = wrapped.getCluster().getDescription().getType(); - return session.hasActiveTransaction() && (clusterType == ClusterType.SHARDED || clusterType == LOAD_BALANCED); - } - private class SessionBindingAsyncConnectionSource implements AsyncConnectionSource { private AsyncConnectionSource wrapped; @@ -218,7 +195,7 @@ public void getConnection(final SingleResultCallback callback) if (t != null) { callback.onResult(null, t); } else { - transactionContext.pinConnection(connection, AsyncConnection::markAsPinned); + transactionContext.pinConnection(assertNotNull(connection), AsyncConnection::markAsPinned); callback.onResult(connection, null); } }); diff --git a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/ClientSessionBindingSpecification.groovy b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/ClientSessionBindingSpecification.groovy index 296f5665b83..4879fa19466 100644 --- a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/ClientSessionBindingSpecification.groovy +++ b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/ClientSessionBindingSpecification.groovy @@ -20,9 +20,6 @@ import com.mongodb.ReadConcern import com.mongodb.ReadPreference import com.mongodb.ServerAddress import com.mongodb.async.FutureResultCallback -import com.mongodb.connection.ClusterConnectionMode -import com.mongodb.connection.ClusterDescription -import com.mongodb.connection.ClusterType import com.mongodb.connection.ServerConnectionState import com.mongodb.connection.ServerDescription import com.mongodb.connection.ServerType @@ -54,15 +51,7 @@ class ClientSessionBindingSpecification extends Specification { def 'should return the session context from the connection source'() { given: def session = Stub(ClientSession) - def wrappedBinding = Mock(AsyncClusterAwareReadWriteBinding) { - getCluster() >> { - Mock(Cluster) { - getDescription() >> { - new ClusterDescription(ClusterConnectionMode.MULTIPLE, ClusterType.REPLICA_SET, []) - } - } - } - } + def wrappedBinding = Mock(AsyncClusterAwareReadWriteBinding) wrappedBinding.retain() >> wrappedBinding def binding = new ClientSessionBinding(session, false, wrappedBinding) @@ -192,9 +181,6 @@ class ClientSessionBindingSpecification extends Specification { .address(new ServerAddress()) .build()), null) } - getDescription() >> { - new ClusterDescription(ClusterConnectionMode.MULTIPLE, ClusterType.REPLICA_SET, []) - } } new AsyncClusterBinding(cluster, ReadPreference.primary(), ReadConcern.DEFAULT, null, IgnorableRequestContext.INSTANCE) } diff --git a/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java b/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java index 41c5dd6fc70..a265ca01a7d 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java +++ b/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java @@ -34,7 +34,10 @@ import com.mongodb.internal.session.SessionContext; import com.mongodb.lang.Nullable; +import java.util.function.Supplier; + import static com.mongodb.connection.ClusterType.LOAD_BALANCED; +import static com.mongodb.connection.ClusterType.SHARDED; import static org.bson.assertions.Assertions.assertNotNull; import static org.bson.assertions.Assertions.notNull; @@ -86,28 +89,17 @@ public int release() { @Override public ConnectionSource getReadConnectionSource() { - if (isConnectionSourcePinningRequired()) { - return new SessionBindingConnectionSource(getPinnedConnectionSource(true)); - } else { - return new SessionBindingConnectionSource(wrapped.getReadConnectionSource()); - } + return new SessionBindingConnectionSource(getConnectionSource(wrapped::getReadConnectionSource)); } @Override public ConnectionSource getReadConnectionSource(final int minWireVersion, final ReadPreference fallbackReadPreference) { - if (isConnectionSourcePinningRequired()) { - return new SessionBindingConnectionSource(getPinnedConnectionSource(true)); - } else { - return new SessionBindingConnectionSource(wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference)); - } + return new SessionBindingConnectionSource(getConnectionSource(() -> + wrapped.getReadConnectionSource(minWireVersion, fallbackReadPreference))); } public ConnectionSource getWriteConnectionSource() { - if (isConnectionSourcePinningRequired()) { - return new SessionBindingConnectionSource(getPinnedConnectionSource(false)); - } else { - return new SessionBindingConnectionSource(wrapped.getWriteConnectionSource()); - } + return new SessionBindingConnectionSource(getConnectionSource(wrapped::getWriteConnectionSource)); } @Override @@ -131,23 +123,23 @@ public OperationContext getOperationContext() { return wrapped.getOperationContext(); } - private boolean isConnectionSourcePinningRequired() { - ClusterType clusterType = wrapped.getCluster().getDescription().getType(); - return session.hasActiveTransaction() && (clusterType == ClusterType.SHARDED || clusterType == LOAD_BALANCED); - } + private ConnectionSource getConnectionSource(final Supplier wrappedConnectionSourceSupplier) { + if (!session.hasActiveTransaction()) { + return wrappedConnectionSourceSupplier.get(); + } - private ConnectionSource getPinnedConnectionSource(final boolean isRead) { - TransactionContext transactionContext = TransactionContext.get(session); - ConnectionSource source; - if (transactionContext == null) { - source = isRead ? wrapped.getReadConnectionSource() : wrapped.getWriteConnectionSource(); - transactionContext = new TransactionContext<>(wrapped.getCluster().getDescription().getType()); - session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext); - transactionContext.release(); // The session is responsible for retaining a reference to the context + if (TransactionContext.get(session) == null) { + ConnectionSource source = wrappedConnectionSourceSupplier.get(); + ClusterType clusterType = source.getServerDescription().getClusterType(); + if (clusterType == SHARDED || clusterType == LOAD_BALANCED) { + TransactionContext transactionContext = new TransactionContext<>(clusterType); + session.setTransactionContext(source.getServerDescription().getAddress(), transactionContext); + transactionContext.release(); // The session is responsible for retaining a reference to the context + } + return source; } else { - source = wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress())); + return wrapped.getConnectionSource(assertNotNull(session.getPinnedServerAddress())); } - return source; } private class SessionBindingConnectionSource implements ConnectionSource { diff --git a/driver-sync/src/test/unit/com/mongodb/client/internal/ClientSessionBindingSpecification.groovy b/driver-sync/src/test/unit/com/mongodb/client/internal/ClientSessionBindingSpecification.groovy index 595672328ad..329e8e9a8b8 100644 --- a/driver-sync/src/test/unit/com/mongodb/client/internal/ClientSessionBindingSpecification.groovy +++ b/driver-sync/src/test/unit/com/mongodb/client/internal/ClientSessionBindingSpecification.groovy @@ -19,9 +19,6 @@ package com.mongodb.client.internal import com.mongodb.ReadConcern import com.mongodb.ReadPreference import com.mongodb.client.ClientSession -import com.mongodb.connection.ClusterConnectionMode -import com.mongodb.connection.ClusterDescription -import com.mongodb.connection.ClusterType import com.mongodb.internal.IgnorableRequestContext import com.mongodb.internal.binding.ClusterBinding import com.mongodb.internal.binding.ConnectionSource @@ -47,15 +44,7 @@ class ClientSessionBindingSpecification extends Specification { def 'should return the session context from the connection source'() { given: def session = Stub(ClientSession) - def wrappedBinding = Mock(ClusterBinding) { - getCluster() >> { - Mock(Cluster) { - getDescription() >> { - new ClusterDescription(ClusterConnectionMode.SINGLE, ClusterType.STANDALONE, []) - } - } - } - } + def wrappedBinding = Mock(ClusterBinding) def binding = new ClientSessionBinding(session, false, wrappedBinding) when: