diff --git a/driver-core/src/main/com/mongodb/internal/connection/DnsMultiServerCluster.java b/driver-core/src/main/com/mongodb/internal/connection/DnsMultiServerCluster.java index 0589d0f7d19..51e28ee5c84 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/DnsMultiServerCluster.java +++ b/driver-core/src/main/com/mongodb/internal/connection/DnsMultiServerCluster.java @@ -21,6 +21,7 @@ import com.mongodb.connection.ClusterId; import com.mongodb.connection.ClusterSettings; import com.mongodb.connection.ClusterType; +import com.mongodb.connection.ServerDescription; import com.mongodb.lang.Nullable; import java.util.ArrayList; @@ -28,6 +29,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; import static com.mongodb.assertions.Assertions.assertNotNull; @@ -38,7 +40,6 @@ public final class DnsMultiServerCluster extends AbstractMultiServerCluster { private final DnsSrvRecordMonitor dnsSrvRecordMonitor; private volatile MongoException srvResolutionException; - public DnsMultiServerCluster(final ClusterId clusterId, final ClusterSettings settings, final ClusterableServerFactory serverFactory, final DnsSrvRecordMonitorFactory dnsSrvRecordMonitorFactory) { super(clusterId, settings, serverFactory); @@ -57,17 +58,33 @@ public void initialize(final Collection hosts) { } } - private Collection applySrvMaxHosts(final Collection hosts) { - Collection newHosts = hosts; + private Collection applySrvMaxHosts(final Collection latestSrvHosts) { Integer srvMaxHosts = getSettings().getSrvMaxHosts(); - if (srvMaxHosts != null && srvMaxHosts > 0) { - if (srvMaxHosts < hosts.size()) { - List newHostsList = new ArrayList<>(hosts); - Collections.shuffle(newHostsList, ThreadLocalRandom.current()); - newHosts = newHostsList.subList(0, srvMaxHosts); - } + if (srvMaxHosts == null || srvMaxHosts <= 0 || latestSrvHosts.size() <= srvMaxHosts) { + return new ArrayList<>(latestSrvHosts); } - return newHosts; + List activeHosts = getActivePriorHosts(latestSrvHosts); + int numNewHostsToAdd = srvMaxHosts - activeHosts.size(); + activeHosts.addAll(addShuffledHosts(latestSrvHosts, activeHosts, numNewHostsToAdd)); + + return activeHosts; + } + + private List getActivePriorHosts(final Collection latestSrvHosts) { + List priorHosts = DnsMultiServerCluster.this.getCurrentDescription().getServerDescriptions().stream() + .map(ServerDescription::getAddress).collect(Collectors.toList()); + priorHosts.removeIf(host -> !latestSrvHosts.contains(host)); + + return priorHosts; + } + + private List addShuffledHosts(final Collection latestSrvHosts, + final List activePriorHosts, final int numNewHostsToAdd) { + List addedHosts = new ArrayList<>(latestSrvHosts); + addedHosts.removeAll(activePriorHosts); + Collections.shuffle(addedHosts, ThreadLocalRandom.current()); + + return addedHosts.subList(0, numNewHostsToAdd); } @Override diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/SrvPollingProseTests.java b/driver-core/src/test/unit/com/mongodb/internal/connection/SrvPollingProseTests.java index a6605725cf8..a0f08a82360 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/SrvPollingProseTests.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/SrvPollingProseTests.java @@ -160,9 +160,10 @@ public void shouldUseAllRecordsWhenSrvMaxHostsIsGreaterThanOrEqualToNumSrvRecord public void shouldUseSrvMaxHostsWhenSrvMaxHostsIsLessThanNumSrvRecords() { int srvMaxHosts = 2; List updatedHosts = asList(firstHost, thirdHost, fourthHost); - initCluster(updatedHosts, srvMaxHosts); + assertEquals(srvMaxHosts, clusterHostsSet().size()); + assertTrue(updatedHosts.contains(firstHost)); assertTrue(updatedHosts.containsAll(clusterHostsSet())); }