From 454dcb813ac6f46234c79b1fce7c260c8aa1b2ff Mon Sep 17 00:00:00 2001 From: Alex Nistico Date: Thu, 27 May 2021 20:11:59 +0800 Subject: [PATCH 1/2] Fix DefaultSubscriptionRegistry to allow multiple destinations per subscription --- .../broker/DefaultSubscriptionRegistry.java | 63 +++++++++++-------- .../DefaultSubscriptionRegistryTests.java | 21 +++++++ 2 files changed, 58 insertions(+), 26 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java index 709f88e4db1f..8e0dd2309185 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java @@ -16,17 +16,23 @@ package org.springframework.messaging.simp.broker; +import static java.util.Collections.emptySet; + import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.Queue; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; +import java.util.stream.Stream; import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; @@ -186,10 +192,8 @@ private Expression getSelectorExpression(MessageHeaders headers) { protected void removeSubscriptionInternal(String sessionId, String subscriptionId, Message message) { SessionInfo info = this.sessionRegistry.getSession(sessionId); if (info != null) { - Subscription subscription = info.removeSubscription(subscriptionId); - if (subscription != null) { - this.destinationCache.updateAfterRemovedSubscription(sessionId, subscription); - } + info.removeSubscription(subscriptionId) + .forEach(subscription -> this.destinationCache.updateAfterRemovedSubscription(sessionId, subscription)); } } @@ -212,10 +216,11 @@ protected MultiValueMap findSubscriptionsInternal(String destina SessionInfo info = this.sessionRegistry.getSession(sessionId); if (info != null) { for (String subscriptionId : subscriptionIds) { - Subscription subscription = info.getSubscription(subscriptionId); - if (subscription != null && evaluateExpression(subscription.getSelector(), message)) { - result.add(sessionId, subscription.getId()); - } + info.getSubscription(subscriptionId).forEach(subscription -> { + if (evaluateExpression(subscription.getSelector(), message)) { + result.add(sessionId, subscription.getId()); + } + }); } } }); @@ -371,9 +376,7 @@ private void removeInternal(String destination, String sessionId, String subscri } public void updateAfterRemovedSession(String sessionId, SessionInfo info) { - for (Subscription subscription : info.getSubscriptions()) { - updateAfterRemovedSubscription(sessionId, subscription); - } + info.getSubscriptions().forEach(subscription -> updateAfterRemovedSubscription(sessionId, subscription)); } } @@ -411,24 +414,25 @@ public SessionInfo removeSubscriptions(String sessionId) { private static final class SessionInfo { // subscriptionId -> Subscription - private final Map subscriptionMap = new ConcurrentHashMap<>(); + private final Map> subscriptionMap = new ConcurrentHashMap<>(); - public Collection getSubscriptions() { - return this.subscriptionMap.values(); + public Stream getSubscriptions() { + return this.subscriptionMap.values().stream().flatMap(Set::stream); } - @Nullable - public Subscription getSubscription(String subscriptionId) { - return this.subscriptionMap.get(subscriptionId); + public Stream getSubscription(String subscriptionId) { + return this.subscriptionMap.getOrDefault(subscriptionId, emptySet()).stream(); } public void addSubscription(Subscription subscription) { - this.subscriptionMap.putIfAbsent(subscription.getId(), subscription); + this.subscriptionMap.computeIfAbsent(subscription.getId(), id -> new HashSet<>()) + .add(subscription); } - @Nullable - public Subscription removeSubscription(String subscriptionId) { - return this.subscriptionMap.remove(subscriptionId); + public Stream removeSubscription(String subscriptionId) { + return Optional.ofNullable(this.subscriptionMap.remove(subscriptionId)) + .map(Set::stream) + .orElse(Stream.empty()); } } @@ -474,18 +478,25 @@ public Expression getSelector() { @Override public boolean equals(@Nullable Object other) { - return (this == other || - (other instanceof Subscription && this.id.equals(((Subscription) other).id))); + if (this == other) { + return true; + } + if (other instanceof Subscription) { + Subscription that = (Subscription)other; + return Objects.equals(this.id, that.id) + && Objects.equals(this.destination, that.destination); + } + return false; } @Override public int hashCode() { - return this.id.hashCode(); + return Objects.hash(this.id, this.destination); } @Override public String toString() { - return "subscription(id=" + this.id + ")"; + return "subscription(id=" + this.id + "; destination=" + this.destination + ")"; } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java index 18970dfbb5c5..d453f3e43aca 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java @@ -110,6 +110,27 @@ public void registerSameSubscriptionTwice() { assertThat(actual.size()).isEqualTo(1); assertThat(actual.get(sessId)).containsExactly(subId); } + + @Test + public void registerSameSubscriptionForDifferentDestinations() { + String sessId = "sess01"; + String subId = "subs01"; + String dest1 = "/foo"; + String dest2 = "user/foo"; + + this.registry.registerSubscription(subscribeMessage(sessId, subId, dest1)); + this.registry.registerSubscription(subscribeMessage(sessId, subId, dest2)); + + MultiValueMap actual = this.registry.findSubscriptions(createMessage(dest1)); + assertThat(actual).isNotNull(); + assertThat(actual.size()).isEqualTo(1); + assertThat(actual.get(sessId)).containsExactly(subId); + + actual = this.registry.findSubscriptions(createMessage(dest2)); + assertThat(actual).isNotNull(); + assertThat(actual.size()).isEqualTo(1); + assertThat(actual.get(sessId)).containsExactly(subId); + } @Test public void registerSubscriptionMultipleSessions() { From 464416827c636add7de5c08da815c44d70a0b3d8 Mon Sep 17 00:00:00 2001 From: Alex Nistico Date: Sat, 29 May 2021 14:34:48 +0800 Subject: [PATCH 2/2] Make method name plural --- .../messaging/simp/broker/DefaultSubscriptionRegistry.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java index 8e0dd2309185..aeba8a3114f2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java @@ -216,7 +216,7 @@ protected MultiValueMap findSubscriptionsInternal(String destina SessionInfo info = this.sessionRegistry.getSession(sessionId); if (info != null) { for (String subscriptionId : subscriptionIds) { - info.getSubscription(subscriptionId).forEach(subscription -> { + info.getSubscriptions(subscriptionId).forEach(subscription -> { if (evaluateExpression(subscription.getSelector(), message)) { result.add(sessionId, subscription.getId()); } @@ -420,7 +420,7 @@ public Stream getSubscriptions() { return this.subscriptionMap.values().stream().flatMap(Set::stream); } - public Stream getSubscription(String subscriptionId) { + public Stream getSubscriptions(String subscriptionId) { return this.subscriptionMap.getOrDefault(subscriptionId, emptySet()).stream(); }