diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java index 94eb1383eab9..40cae8cb3294 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java @@ -16,11 +16,6 @@ package org.springframework.messaging.simp.stomp; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; - import org.springframework.messaging.simp.SimpMessageType; /** @@ -32,62 +27,55 @@ public enum StompCommand { // client - CONNECT, - STOMP, - DISCONNECT, - SUBSCRIBE, - UNSUBSCRIBE, - SEND, - ACK, - NACK, - BEGIN, - COMMIT, - ABORT, + CONNECT(SimpMessageType.CONNECT, 0), + STOMP(SimpMessageType.CONNECT, 0), + DISCONNECT(SimpMessageType.DISCONNECT, 0), + SUBSCRIBE(SimpMessageType.SUBSCRIBE, 3), + UNSUBSCRIBE(SimpMessageType.UNSUBSCRIBE, 2), + SEND(SimpMessageType.MESSAGE, 13), + ACK(SimpMessageType.OTHER, 0), + NACK(SimpMessageType.OTHER, 0), + BEGIN(SimpMessageType.OTHER, 0), + COMMIT(SimpMessageType.OTHER, 0), + ABORT(SimpMessageType.OTHER, 0), // server - CONNECTED, - MESSAGE, - RECEIPT, - ERROR; - - - private static Map messageTypes = new HashMap<>(); - static { - messageTypes.put(StompCommand.CONNECT, SimpMessageType.CONNECT); - messageTypes.put(StompCommand.STOMP, SimpMessageType.CONNECT); - messageTypes.put(StompCommand.SEND, SimpMessageType.MESSAGE); - messageTypes.put(StompCommand.MESSAGE, SimpMessageType.MESSAGE); - messageTypes.put(StompCommand.SUBSCRIBE, SimpMessageType.SUBSCRIBE); - messageTypes.put(StompCommand.UNSUBSCRIBE, SimpMessageType.UNSUBSCRIBE); - messageTypes.put(StompCommand.DISCONNECT, SimpMessageType.DISCONNECT); + CONNECTED(SimpMessageType.OTHER, 0), + MESSAGE(SimpMessageType.MESSAGE, 15), + RECEIPT(SimpMessageType.OTHER, 0), + ERROR(SimpMessageType.OTHER, 12); + + private static final int DESTINATION_REQUIRED = 1; + private static final int SUBSCRIPTION_ID_REQUIRED = 2; + private static final int CONTENT_LENGTH_REQUIRED = 4; + private static final int BODY_ALLOWED = 8; + + private final SimpMessageType simpMessageType; + private final int flags; + + StompCommand(final SimpMessageType simpMessageType, final int flags) { + this.simpMessageType = simpMessageType; + this.flags = flags; } - private static Collection destinationRequired = Arrays.asList(SEND, SUBSCRIBE, MESSAGE); - private static Collection subscriptionIdRequired = Arrays.asList(SUBSCRIBE, UNSUBSCRIBE, MESSAGE); - private static Collection contentLengthRequired = Arrays.asList(SEND, MESSAGE, ERROR); - private static Collection bodyAllowed = Arrays.asList(SEND, MESSAGE, ERROR); - - - public SimpMessageType getMessageType() { - SimpMessageType type = messageTypes.get(this); - return (type != null) ? type : SimpMessageType.OTHER; + return simpMessageType; } public boolean requiresDestination() { - return destinationRequired.contains(this); + return (flags & DESTINATION_REQUIRED) != 0; } public boolean requiresSubscriptionId() { - return subscriptionIdRequired.contains(this); + return (flags & SUBSCRIPTION_ID_REQUIRED) != 0; } public boolean requiresContentLength() { - return contentLengthRequired.contains(this); + return (flags & CONTENT_LENGTH_REQUIRED) != 0; } public boolean isBodyAllowed() { - return bodyAllowed.contains(this); + return (flags & BODY_ALLOWED) != 0; } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCommandTest.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCommandTest.java new file mode 100644 index 000000000000..dc916d064e6d --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCommandTest.java @@ -0,0 +1,82 @@ +package org.springframework.messaging.simp.stomp; + +import org.junit.Test; +import org.springframework.messaging.simp.SimpMessageType; + +import java.util.Arrays; +import java.util.Collection; +import java.util.EnumMap; +import java.util.Map; + +import static org.junit.Assert.*; + +public class StompCommandTest { + + private static final Collection destinationRequired = Arrays.asList(StompCommand.SEND, StompCommand.SUBSCRIBE, StompCommand.MESSAGE); + private static final Collection subscriptionIdRequired = Arrays.asList(StompCommand.SUBSCRIBE, StompCommand.UNSUBSCRIBE, StompCommand.MESSAGE); + private static final Collection contentLengthRequired = Arrays.asList(StompCommand.SEND, StompCommand.MESSAGE, StompCommand.ERROR); + private static final Collection bodyAllowed = Arrays.asList(StompCommand.SEND, StompCommand.MESSAGE, StompCommand.ERROR); + + private static final Map messageTypes = new EnumMap<>(StompCommand.class); + + static { + messageTypes.put(StompCommand.CONNECT, SimpMessageType.CONNECT); + messageTypes.put(StompCommand.STOMP, SimpMessageType.CONNECT); + messageTypes.put(StompCommand.SEND, SimpMessageType.MESSAGE); + messageTypes.put(StompCommand.MESSAGE, SimpMessageType.MESSAGE); + messageTypes.put(StompCommand.SUBSCRIBE, SimpMessageType.SUBSCRIBE); + messageTypes.put(StompCommand.UNSUBSCRIBE, SimpMessageType.UNSUBSCRIBE); + messageTypes.put(StompCommand.DISCONNECT, SimpMessageType.DISCONNECT); + } + + @Test + public void getMessageType() throws Exception { + for (final Map.Entry stompToSimp : messageTypes.entrySet()) { + assertEquals(stompToSimp.getKey().getMessageType(), stompToSimp.getValue()); + } + } + + @Test + public void requiresDestination() throws Exception { + for (final StompCommand stompCommand : StompCommand.values()) { + if (destinationRequired.contains(stompCommand)) { + assertTrue(stompCommand.requiresDestination()); + } else { + assertFalse(stompCommand.requiresDestination()); + } + } + } + + @Test + public void requiresSubscriptionId() throws Exception { + for (final StompCommand stompCommand : StompCommand.values()) { + if (subscriptionIdRequired.contains(stompCommand)) { + assertTrue(stompCommand.requiresSubscriptionId()); + } else { + assertFalse(stompCommand.requiresSubscriptionId()); + } + } + } + + @Test + public void requiresContentLength() throws Exception { + for (final StompCommand stompCommand : StompCommand.values()) { + if (contentLengthRequired.contains(stompCommand)) { + assertTrue(stompCommand.requiresContentLength()); + } else { + assertFalse(stompCommand.requiresContentLength()); + } + } + } + + @Test + public void isBodyAllowed() throws Exception { + for (final StompCommand stompCommand : StompCommand.values()) { + if (bodyAllowed.contains(stompCommand)) { + assertTrue(stompCommand.isBodyAllowed()); + } else { + assertFalse(stompCommand.isBodyAllowed()); + } + } + } +}