diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java index 35742c2f7..aca390246 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/AbstractStateMachine.java @@ -824,7 +824,7 @@ private void switchToState(State state, Message message, Transition setCurrentState(toState, message, transition, true, stateMachine, null, targets); } - callPostStateChangeInterceptors(state, message, transition, stateMachine); + callPostStateChangeInterceptors(toState, message, transition, stateMachine); stateMachineExecutor.execute(); if (isComplete()) { diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/support/StateChangeInterceptorTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/support/StateChangeInterceptorTests.java index 79dde986c..15899a035 100644 --- a/spring-statemachine-core/src/test/java/org/springframework/statemachine/support/StateChangeInterceptorTests.java +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/support/StateChangeInterceptorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2015 the original author or authors. + * Copyright 2015-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; +import java.util.ArrayList; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -28,6 +29,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.statemachine.AbstractStateMachineTests; import org.springframework.statemachine.StateContext; import org.springframework.statemachine.StateMachine; @@ -164,6 +166,82 @@ public void apply(StateMachineAccess function) { assertThat(interceptor.preStateChangeCount, is(1)); } + @Test + public void testIntercept4() throws InterruptedException { + context.register(Config4.class); + context.refresh(); + @SuppressWarnings("unchecked") + StateMachine machine = context.getBean(StateMachineSystemConstants.DEFAULT_ID_STATEMACHINE, StateMachine.class); + TestListener listener = new TestListener(); + machine.addStateListener(listener); + TestStateChangeInterceptor interceptor = new TestStateChangeInterceptor(); + + machine.getStateMachineAccessor().doWithRegion(new StateMachineFunction>() { + + @Override + public void apply(StateMachineAccess function) { + function.addStateMachineInterceptor(interceptor); + } + }); + + machine.start(); + assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(1)); + assertThat(machine.getState().getIds(), containsInAnyOrder(States.S0)); + + interceptor.reset(1); + listener.reset(1); + machine.sendEvent(Events.A); + assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(1)); + assertThat(machine.getState().getIds(), containsInAnyOrder(States.S2)); + assertThat(interceptor.preStateChangeLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(interceptor.preStateChangeCount, is(1)); + assertThat(interceptor.postStateChangeLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(interceptor.postStateChangeCount, is(1)); + assertThat(interceptor.preStateChangeStates.size(), is(1)); + assertThat(interceptor.postStateChangeStates.size(), is(1)); + assertThat(interceptor.preStateChangeStates.get(0).getId(), is(interceptor.postStateChangeStates.get(0).getId())); + } + + @Test + public void testIntercept5() throws InterruptedException { + context.register(Config4.class); + context.refresh(); + @SuppressWarnings("unchecked") + StateMachine machine = context.getBean(StateMachineSystemConstants.DEFAULT_ID_STATEMACHINE, StateMachine.class); + TestListener listener = new TestListener(); + machine.addStateListener(listener); + TestStateChangeInterceptor interceptor = new TestStateChangeInterceptor(); + + machine.getStateMachineAccessor().doWithRegion(new StateMachineFunction>() { + + @Override + public void apply(StateMachineAccess function) { + function.addStateMachineInterceptor(interceptor); + } + }); + + machine.start(); + assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(1)); + assertThat(machine.getState().getIds(), containsInAnyOrder(States.S0)); + + interceptor.reset(1); + listener.reset(1); + machine.sendEvent(MessageBuilder.withPayload(Events.A).setHeader("test", "exists").build()); + assertThat(listener.stateChangedLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(listener.stateChangedCount, is(1)); + assertThat(machine.getState().getIds(), containsInAnyOrder(States.S3)); + assertThat(interceptor.preStateChangeLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(interceptor.preStateChangeCount, is(1)); + assertThat(interceptor.postStateChangeLatch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(interceptor.postStateChangeCount, is(1)); + assertThat(interceptor.preStateChangeStates.size(), is(1)); + assertThat(interceptor.postStateChangeStates.size(), is(1)); + assertThat(interceptor.preStateChangeStates.get(0).getId(), is(interceptor.postStateChangeStates.get(0).getId())); + } + @Configuration @EnableStateMachine static class Config1 extends EnumStateMachineConfigurerAdapter { @@ -341,8 +419,43 @@ public void configure(StateMachineTransitionConfigurer transitio } } + @Configuration + @EnableStateMachine + static class Config4 extends EnumStateMachineConfigurerAdapter { + + @Override + public void configure(StateMachineStateConfigurer states) + throws Exception { + states + .withStates() + .initial(States.S0) + .choice(States.S1) + .state(States.S2) + .state(States.S3); + } + + @Override + public void configure(StateMachineTransitionConfigurer transitions) + throws Exception { + transitions + .withExternal() + .source(States.S0).target(States.S1) + .event(Events.A) + .and() + .withChoice() + .source(States.S1) + .first(States.S3, guard()) + .last(States.S2); + } + + @Bean + public EventHeaderGuard guard() { + return new EventHeaderGuard("test"); + } + } + public static enum States { - S0, S1, S11, S12, S2, S21, S211, S212 + S0, S1, S11, S12, S2, S21, S211, S212, S3; } public static enum Events { @@ -380,6 +493,20 @@ public boolean evaluate(StateContext context) { } } + private static class EventHeaderGuard implements Guard { + + private final String header; + + public EventHeaderGuard(String header) { + this.header = header; + } + + @Override + public boolean evaluate(StateContext context) { + return context.getMessageHeader(header) != null; + } + } + private static class TestListener extends StateMachineListenerAdapter { volatile CountDownLatch stateChangedLatch = new CountDownLatch(1); @@ -401,7 +528,11 @@ public void reset(int c1) { private static class TestStateChangeInterceptor implements StateMachineInterceptor { volatile CountDownLatch preStateChangeLatch = new CountDownLatch(1); + volatile CountDownLatch postStateChangeLatch = new CountDownLatch(1); volatile int preStateChangeCount = 0; + volatile int postStateChangeCount = 0; + ArrayList> preStateChangeStates = new ArrayList<>(); + ArrayList> postStateChangeStates = new ArrayList<>(); @Override public Message preEvent(Message message, StateMachine stateMachine) { @@ -411,6 +542,7 @@ public Message preEvent(Message message, StateMachine state, Message message, Transition transition, StateMachine stateMachine) { + preStateChangeStates.add(state); preStateChangeCount++; preStateChangeLatch.countDown(); @@ -419,6 +551,9 @@ public void preStateChange(State state, Message message, @Override public void postStateChange(State state, Message message, Transition transition, StateMachine stateMachine) { + postStateChangeStates.add(state); + postStateChangeCount++; + postStateChangeLatch.countDown(); } @Override @@ -435,6 +570,10 @@ public StateContext postTransition(StateContext public void reset(int c1) { preStateChangeLatch = new CountDownLatch(c1); preStateChangeCount = 0; + postStateChangeLatch = new CountDownLatch(c1); + postStateChangeCount = 0; + preStateChangeStates.clear(); + postStateChangeStates.clear(); } @Override