diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java index 709f88e4db1f..aeba8a3114f2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java @@ -16,17 +16,23 @@ package org.springframework.messaging.simp.broker; +import static java.util.Collections.emptySet; + import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.Queue; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; +import java.util.stream.Stream; import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; @@ -186,10 +192,8 @@ private Expression getSelectorExpression(MessageHeaders headers) { protected void removeSubscriptionInternal(String sessionId, String subscriptionId, Message message) { SessionInfo info = this.sessionRegistry.getSession(sessionId); if (info != null) { - Subscription subscription = info.removeSubscription(subscriptionId); - if (subscription != null) { - this.destinationCache.updateAfterRemovedSubscription(sessionId, subscription); - } + info.removeSubscription(subscriptionId) + .forEach(subscription -> this.destinationCache.updateAfterRemovedSubscription(sessionId, subscription)); } } @@ -212,10 +216,11 @@ protected MultiValueMap findSubscriptionsInternal(String destina SessionInfo info = this.sessionRegistry.getSession(sessionId); if (info != null) { for (String subscriptionId : subscriptionIds) { - Subscription subscription = info.getSubscription(subscriptionId); - if (subscription != null && evaluateExpression(subscription.getSelector(), message)) { - result.add(sessionId, subscription.getId()); - } + info.getSubscriptions(subscriptionId).forEach(subscription -> { + if (evaluateExpression(subscription.getSelector(), message)) { + result.add(sessionId, subscription.getId()); + } + }); } } }); @@ -371,9 +376,7 @@ private void removeInternal(String destination, String sessionId, String subscri } public void updateAfterRemovedSession(String sessionId, SessionInfo info) { - for (Subscription subscription : info.getSubscriptions()) { - updateAfterRemovedSubscription(sessionId, subscription); - } + info.getSubscriptions().forEach(subscription -> updateAfterRemovedSubscription(sessionId, subscription)); } } @@ -411,24 +414,25 @@ public SessionInfo removeSubscriptions(String sessionId) { private static final class SessionInfo { // subscriptionId -> Subscription - private final Map subscriptionMap = new ConcurrentHashMap<>(); + private final Map> subscriptionMap = new ConcurrentHashMap<>(); - public Collection getSubscriptions() { - return this.subscriptionMap.values(); + public Stream getSubscriptions() { + return this.subscriptionMap.values().stream().flatMap(Set::stream); } - @Nullable - public Subscription getSubscription(String subscriptionId) { - return this.subscriptionMap.get(subscriptionId); + public Stream getSubscriptions(String subscriptionId) { + return this.subscriptionMap.getOrDefault(subscriptionId, emptySet()).stream(); } public void addSubscription(Subscription subscription) { - this.subscriptionMap.putIfAbsent(subscription.getId(), subscription); + this.subscriptionMap.computeIfAbsent(subscription.getId(), id -> new HashSet<>()) + .add(subscription); } - @Nullable - public Subscription removeSubscription(String subscriptionId) { - return this.subscriptionMap.remove(subscriptionId); + public Stream removeSubscription(String subscriptionId) { + return Optional.ofNullable(this.subscriptionMap.remove(subscriptionId)) + .map(Set::stream) + .orElse(Stream.empty()); } } @@ -474,18 +478,25 @@ public Expression getSelector() { @Override public boolean equals(@Nullable Object other) { - return (this == other || - (other instanceof Subscription && this.id.equals(((Subscription) other).id))); + if (this == other) { + return true; + } + if (other instanceof Subscription) { + Subscription that = (Subscription)other; + return Objects.equals(this.id, that.id) + && Objects.equals(this.destination, that.destination); + } + return false; } @Override public int hashCode() { - return this.id.hashCode(); + return Objects.hash(this.id, this.destination); } @Override public String toString() { - return "subscription(id=" + this.id + ")"; + return "subscription(id=" + this.id + "; destination=" + this.destination + ")"; } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java index 18970dfbb5c5..d453f3e43aca 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java @@ -110,6 +110,27 @@ public void registerSameSubscriptionTwice() { assertThat(actual.size()).isEqualTo(1); assertThat(actual.get(sessId)).containsExactly(subId); } + + @Test + public void registerSameSubscriptionForDifferentDestinations() { + String sessId = "sess01"; + String subId = "subs01"; + String dest1 = "/foo"; + String dest2 = "user/foo"; + + this.registry.registerSubscription(subscribeMessage(sessId, subId, dest1)); + this.registry.registerSubscription(subscribeMessage(sessId, subId, dest2)); + + MultiValueMap actual = this.registry.findSubscriptions(createMessage(dest1)); + assertThat(actual).isNotNull(); + assertThat(actual.size()).isEqualTo(1); + assertThat(actual.get(sessId)).containsExactly(subId); + + actual = this.registry.findSubscriptions(createMessage(dest2)); + assertThat(actual).isNotNull(); + assertThat(actual.size()).isEqualTo(1); + assertThat(actual.get(sessId)).containsExactly(subId); + } @Test public void registerSubscriptionMultipleSessions() {