diff --git a/spring-messaging/spring-messaging.gradle b/spring-messaging/spring-messaging.gradle index b1e9bae989e8..8b5233b1c7e2 100644 --- a/spring-messaging/spring-messaging.gradle +++ b/spring-messaging/spring-messaging.gradle @@ -34,4 +34,5 @@ dependencies { testRuntime("com.sun.xml.bind:jaxb-core") testRuntime("com.sun.xml.bind:jaxb-impl") testRuntime("com.sun.activation:javax.activation") + testRuntime(project(":spring-context")) } diff --git a/spring-messaging/src/jmh/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryBenchmark.java b/spring-messaging/src/jmh/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryBenchmark.java new file mode 100644 index 000000000000..d33412bbab47 --- /dev/null +++ b/spring-messaging/src/jmh/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryBenchmark.java @@ -0,0 +1,192 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.broker; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.infra.Blackhole; + +import org.springframework.messaging.Message; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.MultiValueMap; + +@BenchmarkMode(Mode.Throughput) +public class DefaultSubscriptionRegistryBenchmark { + + @State(Scope.Benchmark) + public static class ServerState { + @Param("1000") + public int sessions; + + @Param("10") + public int destinations; + + @Param({"0", "1024"}) + int cacheSizeLimit; + + @Param({"none", "patternSubscriptions", "selectorHeaders"}) + String specialization; + + public DefaultSubscriptionRegistry registry; + + public String[] destinationIds; + + public String[] sessionIds; + + public AtomicInteger uniqueIdGenerator; + + public Message findMessage; + + @Setup(Level.Trial) + public void doSetup() { + this.findMessage = MessageBuilder.createMessage("", SimpMessageHeaderAccessor.create().getMessageHeaders()); + this.uniqueIdGenerator = new AtomicInteger(); + + this.registry = new DefaultSubscriptionRegistry(); + this.registry.setCacheLimit(this.cacheSizeLimit); + this.registry.setSelectorHeaderName("selectorHeaders".equals(this.specialization) ? "someSelector" : null); + + this.destinationIds = IntStream.range(0, this.destinations) + .mapToObj(i -> "/some/destination/" + i) + .toArray(String[]::new); + + this.sessionIds = IntStream.range(0, this.sessions) + .mapToObj(i -> "sessionId_" + i) + .toArray(String[]::new); + + for (String sessionId : this.sessionIds) { + for (String destinationId : this.destinationIds) { + registerSubscriptions(sessionId, destinationId); + } + } + } + + public void registerSubscriptions(String sessionId, String destination) { + if ("patternSubscriptions".equals(this.specialization)) { + destination = "/**/" + destination; + } + String subscriptionId = "subscription_" + this.uniqueIdGenerator.incrementAndGet(); + this.registry.registerSubscription(subscribeMessage(sessionId, subscriptionId, destination)); + } + } + + @State(Scope.Thread) + public static class Requests { + @Param({"none", "sameDestination", "sameSession"}) + String contention; + + public String session; + + public Message subscribe; + + public String findDestination; + + public Message unsubscribe; + + @Setup(Level.Trial) + public void doSetup(ServerState serverState) { + int uniqueNumber = serverState.uniqueIdGenerator.incrementAndGet(); + + if ("sameDestination".equals(this.contention)) { + this.findDestination = serverState.destinationIds[0]; + } + else { + this.findDestination = serverState.destinationIds[uniqueNumber % serverState.destinationIds.length]; + } + + if ("sameSession".equals(this.contention)) { + this.session = serverState.sessionIds[0]; + } + else { + this.session = serverState.sessionIds[uniqueNumber % serverState.sessionIds.length]; + } + + String subscription = String.valueOf(uniqueNumber); + String subscribeDestination = "patternSubscriptions".equals(serverState.specialization) ? + "/**/" + this.findDestination : this.findDestination; + this.subscribe = subscribeMessage(this.session, subscription, subscribeDestination); + + this.unsubscribe = unsubscribeMessage(this.session, subscription); + } + } + + @State(Scope.Thread) + public static class FindRequest { + @Param({"none", "noSubscribers", "sameDestination"}) + String contention; + + public String destination; + + @Setup(Level.Trial) + public void doSetup(ServerState serverState) { + switch (this.contention) { + case "noSubscribers": + this.destination = "someDestination_withNoSubscribers_" + serverState.uniqueIdGenerator.incrementAndGet(); + break; + case "sameDestination": + this.destination = serverState.destinationIds[0]; + break; + case "none": + int uniqueNumber = serverState.uniqueIdGenerator.getAndIncrement(); + this.destination = serverState.destinationIds[uniqueNumber % serverState.destinationIds.length]; + break; + default: + throw new IllegalStateException(); + } + } + } + + @Benchmark + public void registerUnregister(ServerState serverState, Requests request, Blackhole blackhole) { + serverState.registry.registerSubscription(request.subscribe); + blackhole.consume(serverState.registry.findSubscriptionsInternal(request.findDestination, serverState.findMessage)); + serverState.registry.unregisterSubscription(request.unsubscribe); + blackhole.consume(serverState.registry.findSubscriptionsInternal(request.findDestination, serverState.findMessage)); + } + + @Benchmark + public MultiValueMap find(ServerState serverState, FindRequest request) { + return serverState.registry.findSubscriptionsInternal(request.destination, serverState.findMessage); + } + + public static Message subscribeMessage(String sessionId, String subscriptionId, String dest) { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE); + accessor.setSessionId(sessionId); + accessor.setSubscriptionId(subscriptionId); + accessor.setDestination(dest); + accessor.setNativeHeader("someSelector", "true"); + return MessageBuilder.createMessage("", accessor.getMessageHeaders()); + } + + public static Message unsubscribeMessage(String sessionId, String subscriptionId) { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE); + accessor.setSessionId(sessionId); + accessor.setSubscriptionId(subscriptionId); + return MessageBuilder.createMessage("", accessor.getMessageHeaders()); + } +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java index 86692ac30d2f..6a8434018b29 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java @@ -16,15 +16,17 @@ package org.springframework.messaging.simp.broker; +import java.util.ArrayList; import java.util.Collection; -import java.util.HashSet; -import java.util.LinkedHashMap; +import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Set; +import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CopyOnWriteArraySet; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; @@ -34,6 +36,7 @@ import org.springframework.expression.spel.SpelEvaluationException; import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.expression.spel.support.SimpleEvaluationContext; +import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; @@ -72,7 +75,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { private PathMatcher pathMatcher = new AntPathMatcher(); - private volatile int cacheLimit = DEFAULT_CACHE_LIMIT; + private int cacheLimit = DEFAULT_CACHE_LIMIT; @Nullable private String selectorHeaderName = "selector"; @@ -106,6 +109,7 @@ public PathMatcher getPathMatcher() { */ public void setCacheLimit(int cacheLimit) { this.cacheLimit = cacheLimit; + this.destinationCache.ensureCacheLimit(); } /** @@ -142,14 +146,17 @@ public String getSelectorHeaderName() { return this.selectorHeaderName; } - @Override - protected void addSubscriptionInternal( - String sessionId, String subsId, String destination, Message message) { - + protected void addSubscriptionInternal(@NonNull String sessionId, @NonNull String subscriptionId, + @NonNull String destination, @NonNull Message message) { Expression expression = getSelectorExpression(message.getHeaders()); - this.subscriptionRegistry.addSubscription(sessionId, subsId, destination, expression); - this.destinationCache.updateAfterNewSubscription(destination, sessionId, subsId); + boolean isAntPattern = this.pathMatcher.isPattern(destination); + Subscription subscription = new Subscription(subscriptionId, expression, destination, isAntPattern); + + Subscription previousValue = this.subscriptionRegistry.addSubscription(sessionId, subscriptionId, subscription); + if (previousValue == null) { + this.destinationCache.updateAfterNewSubscription(destination, isAntPattern, sessionId, subscriptionId); + } } @Nullable @@ -179,9 +186,9 @@ private Expression getSelectorExpression(MessageHeaders headers) { protected void removeSubscriptionInternal(String sessionId, String subsId, Message message) { SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId); if (info != null) { - String destination = info.removeSubscription(subsId); - if (destination != null) { - this.destinationCache.updateAfterRemovedSubscription(sessionId, subsId); + Subscription subscription = info.removeSubscription(subsId); + if (subscription != null) { + this.destinationCache.updateAfterRemovedSubscription(sessionId, subscription); } } } @@ -190,13 +197,13 @@ protected void removeSubscriptionInternal(String sessionId, String subsId, Messa public void unregisterAllSubscriptions(String sessionId) { SessionSubscriptionInfo info = this.subscriptionRegistry.removeSubscriptions(sessionId); if (info != null) { - this.destinationCache.updateAfterRemovedSession(info); + this.destinationCache.updateAfterRemovedSession(sessionId, info.getSubscriptions()); } } @Override protected MultiValueMap findSubscriptionsInternal(String destination, Message message) { - MultiValueMap result = this.destinationCache.getSubscriptions(destination, message); + MultiValueMap result = this.destinationCache.getSubscriptions(destination); return filterSubscriptions(result, message); } @@ -207,168 +214,181 @@ private MultiValueMap filterSubscriptions( return allMatches; } MultiValueMap result = new LinkedMultiValueMap<>(allMatches.size()); - allMatches.forEach((sessionId, subIds) -> { - for (String subId : subIds) { - SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId); - if (info == null) { - continue; - } - Subscription sub = info.getSubscription(subId); - if (sub == null) { - continue; - } - Expression expression = sub.getSelectorExpression(); - if (expression == null) { - result.add(sessionId, subId); - continue; - } - try { - if (Boolean.TRUE.equals(expression.getValue(messageEvalContext, message, Boolean.class))) { - result.add(sessionId, subId); - } - } - catch (SpelEvaluationException ex) { - if (logger.isDebugEnabled()) { - logger.debug("Failed to evaluate selector: " + ex.getMessage()); + allMatches.forEach((sessionId, subscriptionsIds) -> { + SessionSubscriptionInfo subscriptions = this.subscriptionRegistry.getSubscriptions(sessionId); + if (subscriptions != null) { + for (String subscriptionId : subscriptionsIds) { + Subscription subscription = subscriptions.getSubscription(subscriptionId); + if (subscription != null && evaluateExpression(subscription.getSelectorExpression(), message)) { + result.add(sessionId, subscription.getId()); } } - catch (Throwable ex) { - logger.debug("Failed to evaluate selector", ex); - } } }); + return result; } - @Override - public String toString() { - return "DefaultSubscriptionRegistry[" + this.destinationCache + ", " + this.subscriptionRegistry + "]"; + private boolean evaluateExpression(@Nullable Expression expression, Message message) { + boolean result = false; + try { + if (expression == null || Boolean.TRUE.equals(expression.getValue(messageEvalContext, message, Boolean.class))) { + result = true; + } + } + catch (SpelEvaluationException ex) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to evaluate selector: " + ex.getMessage()); + } + } + catch (Throwable ex) { + logger.debug("Failed to evaluate selector", ex); + } + return result; } - /** * A cache for destinations previously resolved via * {@link DefaultSubscriptionRegistry#findSubscriptionsInternal(String, Message)}. */ - private class DestinationCache { + private final class DestinationCache { /** Map from destination to {@code } for fast look-ups. */ - private final Map> accessCache = + private final Map> destinationCache = new ConcurrentHashMap<>(DEFAULT_CACHE_LIMIT); - /** Map from destination to {@code } with locking. */ - @SuppressWarnings("serial") - private final Map> updateCache = - new LinkedHashMap>(DEFAULT_CACHE_LIMIT, 0.75f, true) { - @Override - protected boolean removeEldestEntry(Map.Entry> eldest) { - if (size() > getCacheLimit()) { - accessCache.remove(eldest.getKey()); - return true; - } - else { - return false; - } - } - }; - - - public LinkedMultiValueMap getSubscriptions(String destination, Message message) { - LinkedMultiValueMap result = this.accessCache.get(destination); - if (result == null) { - synchronized (this.updateCache) { - result = new LinkedMultiValueMap<>(); - for (SessionSubscriptionInfo info : subscriptionRegistry.getAllSubscriptions()) { - for (String destinationPattern : info.getDestinations()) { - if (getPathMatcher().match(destinationPattern, destination)) { - for (Subscription sub : info.getSubscriptions(destinationPattern)) { - result.add(info.sessionId, sub.getId()); - } - } - } - } - if (!result.isEmpty()) { - this.updateCache.put(destination, result.deepCopy()); - this.accessCache.put(destination, result); + private final Queue cacheEvictionPolicy = new ConcurrentLinkedQueue<>(); + + private final AtomicInteger cacheSize = new AtomicInteger(); + + public LinkedMultiValueMap getSubscriptions(String destination) { + LinkedMultiValueMap subscriptions = this.destinationCache.get(destination); + if (subscriptions == null) { + subscriptions = this.destinationCache.computeIfAbsent(destination, dest -> { + LinkedMultiValueMap sessionSubscriptions = calculateSubscriptions(destination); + this.cacheEvictionPolicy.add(destination); + this.cacheSize.incrementAndGet(); + return sessionSubscriptions; + }); + ensureCacheLimit(); + } + return subscriptions; + } + + @NonNull + private LinkedMultiValueMap calculateSubscriptions(String destination) { + LinkedMultiValueMap sessionsToSubscriptions = new LinkedMultiValueMap<>(); + + DefaultSubscriptionRegistry.this.subscriptionRegistry.forEachSubscription((sessionId, subscriptionDetail) -> { + if (subscriptionDetail.isAntPattern()) { + if (pathMatcher.match(subscriptionDetail.getDestination(), destination)) { + sessionsToSubscriptions.compute(sessionId, (s, subscriptions) -> + addToList(subscriptionDetail.getId(), subscriptions)); } } + else if (destination.equals(subscriptionDetail.getDestination())) { + sessionsToSubscriptions.compute(sessionId, (s, subscriptions) -> + addToList(subscriptionDetail.getId(), subscriptions)); + } + }); + return sessionsToSubscriptions; + } + + @NonNull + private List addToList(String subscriptionId, @Nullable List subscriptions) { + if (subscriptions == null) { + return Collections.singletonList(subscriptionId); + } + else { + List newSubscriptions = new ArrayList<>(subscriptions.size() + 1); + newSubscriptions.addAll(subscriptions); + newSubscriptions.add(subscriptionId); + return newSubscriptions; } - return result; - } - - public void updateAfterNewSubscription(String destination, String sessionId, String subsId) { - synchronized (this.updateCache) { - this.updateCache.forEach((cachedDestination, subscriptions) -> { - if (getPathMatcher().match(destination, cachedDestination)) { - // Subscription id's may also be populated via getSubscriptions() - List subsForSession = subscriptions.get(sessionId); - if (subsForSession == null || !subsForSession.contains(subsId)) { - subscriptions.add(sessionId, subsId); - this.accessCache.put(cachedDestination, subscriptions.deepCopy()); - } + } + + private void ensureCacheLimit() { + int size = this.cacheSize.get(); + if (size > cacheLimit) { + do { + if (this.cacheSize.compareAndSet(size, size - 1)) { + this.destinationCache.remove(this.cacheEvictionPolicy.poll()); } - }); + } while ((size = this.cacheSize.get()) > cacheLimit); } } - public void updateAfterRemovedSubscription(String sessionId, String subsId) { - synchronized (this.updateCache) { - Set destinationsToRemove = new HashSet<>(); - this.updateCache.forEach((destination, sessionMap) -> { - List subscriptions = sessionMap.get(sessionId); - if (subscriptions != null) { - subscriptions.remove(subsId); - if (subscriptions.isEmpty()) { - sessionMap.remove(sessionId); - } - if (sessionMap.isEmpty()) { - destinationsToRemove.add(destination); - } - else { - this.accessCache.put(destination, sessionMap.deepCopy()); - } + public void updateAfterNewSubscription(String destination, boolean isPattern, String sessionId, String subscriptionId) { + if (isPattern) { + for (String cachedDestination : this.destinationCache.keySet()) { + if (pathMatcher.match(destination, cachedDestination)) { + addToDestination(cachedDestination, sessionId, subscriptionId); } - }); - for (String destination : destinationsToRemove) { - this.updateCache.remove(destination); - this.accessCache.remove(destination); } } + else { + addToDestination(destination, sessionId, subscriptionId); + } } - public void updateAfterRemovedSession(SessionSubscriptionInfo info) { - synchronized (this.updateCache) { - Set destinationsToRemove = new HashSet<>(); - this.updateCache.forEach((destination, sessionMap) -> { - if (sessionMap.remove(info.getSessionId()) != null) { - if (sessionMap.isEmpty()) { - destinationsToRemove.add(destination); - } - else { - this.accessCache.put(destination, sessionMap.deepCopy()); - } + private void addToDestination(String destination, String sessionId, String subscriptionId) { + this.destinationCache.computeIfPresent(destination, (dest, sessionsToSubscriptions) -> { + sessionsToSubscriptions = sessionsToSubscriptions.clone(); + sessionsToSubscriptions.compute(sessionId, (s, subscriptions) -> addToList(subscriptionId, subscriptions)); + return sessionsToSubscriptions; + }); + } + + public void updateAfterRemovedSubscription(String sessionId, Subscription subscriptionDetail) { + if (subscriptionDetail.isAntPattern()) { + String patternDestination = subscriptionDetail.getDestination(); + for (String destination : this.destinationCache.keySet()) { + if (pathMatcher.match(patternDestination, destination)) { + removeInternal(destination, sessionId, subscriptionDetail.getId()); } - }); - for (String destination : destinationsToRemove) { - this.updateCache.remove(destination); - this.accessCache.remove(destination); } } + else { + removeInternal(subscriptionDetail.getDestination(), sessionId, subscriptionDetail.getId()); + } } - @Override - public String toString() { - return "cache[" + this.accessCache.size() + " destination(s)]"; + private void removeInternal(String destination, String sessionId, String subscription) { + this.destinationCache.computeIfPresent(destination, (dest, subscriptions) -> { + subscriptions = subscriptions.clone(); + subscriptions.computeIfPresent(sessionId, (session, subs) -> { + /* it is very likely that one session has only one subscription per one destination */ + if (subs.size() == 1 && subscription.equals(subs.get(0))) { + return null; + } + else { + subs = new ArrayList<>(subs); + subs.remove(subscription); + return emptyListToNUll(subs); + } + }); + return subscriptions; + }); + } + + @Nullable + private List emptyListToNUll(@NonNull List list) { + return list.isEmpty() ? null : list; } - } + public void updateAfterRemovedSession(String sessionId, Collection subscriptionDetails) { + for (Subscription subscriptionDetail : subscriptionDetails) { + updateAfterRemovedSubscription(sessionId, subscriptionDetail); + } + } + } /** * Provide access to session subscriptions by sessionId. */ - private static class SessionSubscriptionRegistry { + private static final class SessionSubscriptionRegistry { - // sessionId -> SessionSubscriptionInfo + // 'sessionId' -> 'subscriptionId' -> 'destination, selector expression' private final ConcurrentMap sessions = new ConcurrentHashMap<>(); @Nullable @@ -376,119 +396,51 @@ public SessionSubscriptionInfo getSubscriptions(String sessionId) { return this.sessions.get(sessionId); } - public Collection getAllSubscriptions() { - return this.sessions.values(); + public void forEachSubscription(BiConsumer consumer) { + this.sessions.forEach((sessionId, subscriptions) -> + subscriptions.getSubscriptions().forEach(subscriptionDetail -> + consumer.accept(sessionId, subscriptionDetail))); } - public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId, - String destination, @Nullable Expression selectorExpression) { - - SessionSubscriptionInfo info = this.sessions.get(sessionId); - if (info == null) { - info = new SessionSubscriptionInfo(sessionId); - SessionSubscriptionInfo value = this.sessions.putIfAbsent(sessionId, info); - if (value != null) { - info = value; - } - } - info.addSubscription(destination, subscriptionId, selectorExpression); - return info; + @Nullable + public Subscription addSubscription(String sessionId, String subscriptionId, Subscription subscriptionDetail) { + SessionSubscriptionInfo subscriptions = this.sessions.computeIfAbsent(sessionId, s -> new SessionSubscriptionInfo()); + return subscriptions.addSubscription(subscriptionId, subscriptionDetail); } @Nullable public SessionSubscriptionInfo removeSubscriptions(String sessionId) { return this.sessions.remove(sessionId); } - - @Override - public String toString() { - return "registry[" + this.sessions.size() + " sessions]"; - } } - /** * Hold subscriptions for a session. */ - private static class SessionSubscriptionInfo { - - private final String sessionId; - - // destination -> subscriptions - private final Map> destinationLookup = new ConcurrentHashMap<>(4); - - public SessionSubscriptionInfo(String sessionId) { - Assert.notNull(sessionId, "'sessionId' must not be null"); - this.sessionId = sessionId; - } + private static final class SessionSubscriptionInfo { - public String getSessionId() { - return this.sessionId; - } - - public Set getDestinations() { - return this.destinationLookup.keySet(); - } + private final Map subscriptionLookup = new ConcurrentHashMap<>(); - public Set getSubscriptions(String destination) { - return this.destinationLookup.get(destination); + public Collection getSubscriptions() { + return this.subscriptionLookup.values(); } @Nullable public Subscription getSubscription(String subscriptionId) { - for (Map.Entry> destinationEntry : - this.destinationLookup.entrySet()) { - for (Subscription sub : destinationEntry.getValue()) { - if (sub.getId().equalsIgnoreCase(subscriptionId)) { - return sub; - } - } - } - return null; - } - - public void addSubscription(String destination, String subscriptionId, @Nullable Expression selectorExpression) { - Set subs = this.destinationLookup.get(destination); - if (subs == null) { - synchronized (this.destinationLookup) { - subs = this.destinationLookup.get(destination); - if (subs == null) { - subs = new CopyOnWriteArraySet<>(); - this.destinationLookup.put(destination, subs); - } - } - } - subs.add(new Subscription(subscriptionId, selectorExpression)); + return this.subscriptionLookup.get(subscriptionId); } @Nullable - public String removeSubscription(String subscriptionId) { - for (Map.Entry> destinationEntry : - this.destinationLookup.entrySet()) { - Set subs = destinationEntry.getValue(); - if (subs != null) { - for (Subscription sub : subs) { - if (sub.getId().equals(subscriptionId) && subs.remove(sub)) { - synchronized (this.destinationLookup) { - if (subs.isEmpty()) { - this.destinationLookup.remove(destinationEntry.getKey()); - } - } - return destinationEntry.getKey(); - } - } - } - } - return null; + public Subscription addSubscription(String subscriptionId, Subscription subscriptionDetail) { + return this.subscriptionLookup.putIfAbsent(subscriptionId, subscriptionDetail); } - @Override - public String toString() { - return "[sessionId=" + this.sessionId + ", subscriptions=" + this.destinationLookup + "]"; + @Nullable + public Subscription removeSubscription(String subscriptionId) { + return this.subscriptionLookup.remove(subscriptionId); } } - private static final class Subscription { private final String id; @@ -496,16 +448,31 @@ private static final class Subscription { @Nullable private final Expression selectorExpression; - public Subscription(String id, @Nullable Expression selector) { + private final String destination; + + private final boolean isAntPattern; + + public Subscription(String id, @Nullable Expression selector, String destination, boolean isAntPattern) { Assert.notNull(id, "Subscription id must not be null"); + Assert.notNull(destination, "Subscription destination must not be null"); this.id = id; this.selectorExpression = selector; + this.destination = destination; + this.isAntPattern = isAntPattern; } public String getId() { return this.id; } + public String getDestination() { + return this.destination; + } + + public boolean isAntPattern() { + return this.isAntPattern; + } + @Nullable public Expression getSelectorExpression() { return this.selectorExpression; diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java index b0ce3f561021..18970dfbb5c5 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java @@ -96,6 +96,21 @@ public void registerSubscriptionOneSession() { assertThat(sort(actual.get(sessId))).isEqualTo(subscriptionIds); } + @Test + public void registerSameSubscriptionTwice() { + String sessId = "sess01"; + String subId = "subs01"; + String dest = "/foo"; + + this.registry.registerSubscription(subscribeMessage(sessId, subId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subId, dest)); + + MultiValueMap actual = this.registry.findSubscriptions(createMessage(dest)); + assertThat(actual).isNotNull(); + assertThat(actual.size()).isEqualTo(1); + assertThat(actual.get(sessId)).containsExactly(subId); + } + @Test public void registerSubscriptionMultipleSessions() { List sessIds = Arrays.asList("sess01", "sess02", "sess03"); @@ -148,7 +163,7 @@ public void registerSubscriptionsWithSimpleAndPatternDestinations() { MultiValueMap actual = this.registry.findSubscriptions(destNasdaqIbmMessage); assertThat(actual).isNotNull(); assertThat(actual.size()).isEqualTo(1); - assertThat(actual.get(sess1)).isEqualTo(Arrays.asList(subs2, subs1)); + assertThat(actual.get(sess1)).containsExactlyInAnyOrder(subs2, subs1); this.registry.registerSubscription(subscribeMessage(sess2, subs1, destNasdaqIbm)); this.registry.registerSubscription(subscribeMessage(sess2, subs2, "/topic/PRICE.STOCK.NYSE.IBM")); @@ -157,7 +172,7 @@ public void registerSubscriptionsWithSimpleAndPatternDestinations() { actual = this.registry.findSubscriptions(destNasdaqIbmMessage); assertThat(actual).isNotNull(); assertThat(actual.size()).isEqualTo(2); - assertThat(actual.get(sess1)).isEqualTo(Arrays.asList(subs2, subs1)); + assertThat(actual.get(sess1)).containsExactlyInAnyOrder(subs2, subs1); assertThat(actual.get(sess2)).isEqualTo(Collections.singletonList(subs1)); this.registry.unregisterAllSubscriptions(sess1); @@ -173,7 +188,7 @@ public void registerSubscriptionsWithSimpleAndPatternDestinations() { actual = this.registry.findSubscriptions(destNasdaqIbmMessage); assertThat(actual).isNotNull(); assertThat(actual.size()).isEqualTo(2); - assertThat(actual.get(sess1)).isEqualTo(Arrays.asList(subs1, subs2)); + assertThat(actual.get(sess1)).containsExactlyInAnyOrder(subs1, subs2); assertThat(actual.get(sess2)).isEqualTo(Collections.singletonList(subs1)); this.registry.unregisterSubscription(unsubscribeMessage(sess1, subs2));