diff --git a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineExecutor.java b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineExecutor.java index 28bbea4f6..ed93e7a54 100644 --- a/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineExecutor.java +++ b/spring-statemachine-core/src/main/java/org/springframework/statemachine/support/DefaultStateMachineExecutor.java @@ -449,7 +449,15 @@ public void triggered() { log.debug("TimedTrigger triggered " + trigger); } triggerQueue.add(new TriggerQueueItem(trigger, null)); - scheduleEventQueueProcessing(); + // isRunning() is also called in scheduleEventQueueProcessing() + // but we may get into lifecycle deadlock if we schedule here + // from a different thread. may happen if timer fires immediately + // and we're not exactly gone through start sequence. + // however this trigger is most likely getting processed as + // it was added to trigger queue. + if (isRunning()) { + scheduleEventQueueProcessing(); + } } }); } diff --git a/spring-statemachine-core/src/test/java/org/springframework/statemachine/support/DefaultStateMachineExecutorTests.java b/spring-statemachine-core/src/test/java/org/springframework/statemachine/support/DefaultStateMachineExecutorTests.java new file mode 100644 index 000000000..f1120108f --- /dev/null +++ b/spring-statemachine-core/src/test/java/org/springframework/statemachine/support/DefaultStateMachineExecutorTests.java @@ -0,0 +1,270 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.statemachine.support; + +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.scheduling.concurrent.ConcurrentTaskScheduler; +import org.springframework.statemachine.StateContext; +import org.springframework.statemachine.StateMachine; +import org.springframework.statemachine.state.State; +import org.springframework.statemachine.support.StateMachineExecutor.StateMachineExecutorTransit; +import org.springframework.statemachine.transition.Transition; +import org.springframework.statemachine.trigger.EventTrigger; +import org.springframework.statemachine.trigger.TimerTrigger; +import org.springframework.statemachine.trigger.Trigger; + +public class DefaultStateMachineExecutorTests { + + @SuppressWarnings("unchecked") + @Test + public void testSimpleExecute() throws Exception { + + SyncTaskExecutor taskExecutor = new SyncTaskExecutor(); + Message message = MessageBuilder.withPayload("E1").build(); + + EventTrigger triggerE1 = new EventTrigger("E1"); + + State stateS1 = mock(State.class); + when(stateS1.getId()).thenReturn("S1"); + when(stateS1.getIds()).thenReturn(Arrays.asList("S1")); + State stateS2 = mock(State.class); + when(stateS2.getId()).thenReturn("S2"); + when(stateS2.getIds()).thenReturn(Arrays.asList("S2")); + + Transition transitionS1S2 = mock(Transition.class); + when(transitionS1S2.getSource()).thenReturn(stateS1); + when(transitionS1S2.getTarget()).thenReturn(stateS2); + when(transitionS1S2.getTrigger()).thenReturn(triggerE1); + when(transitionS1S2.transit(any())).thenReturn(true); + + StateMachine stateMachine = mock(StateMachine.class); + when(stateMachine.getState()).thenReturn(stateS1); + + Collection> transitions = new ArrayList<>(); + transitions.add(transitionS1S2); + + Map, Transition> triggerToTransitionMap = new HashMap<>(); + triggerToTransitionMap.put(triggerE1, transitionS1S2); + + List> triggerlessTransitions = new ArrayList<>(); + + Transition initialTransition = mock(Transition.class); + Message initialEvent = null; + + DefaultStateMachineExecutor executor = new DefaultStateMachineExecutor<>( + stateMachine, + stateMachine, + transitions, + triggerToTransitionMap, + triggerlessTransitions, + initialTransition, + initialEvent); + + executor.setTaskExecutor(taskExecutor); + + TestStateMachineExecutorTransit transit = new TestStateMachineExecutorTransit(); + transit.reset(2); + executor.setStateMachineExecutorTransit(transit); + executor.start(); + + executor.queueEvent(message); + executor.execute(); + + assertThat(transit.latch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(transit.transitions.size(), is(2)); + + } + + @SuppressWarnings("unchecked") + @Test + public void testSimpleTimer() throws Exception { + SyncTaskExecutor taskExecutor = new SyncTaskExecutor(); + ConcurrentTaskScheduler taskScheduler = new ConcurrentTaskScheduler(); + + EventTrigger triggerE1 = new EventTrigger("E1"); + + TimerTrigger triggerTimer = new TimerTrigger<>(1000, 1); + triggerTimer.setTaskScheduler(taskScheduler); + + State stateS1 = mock(State.class); + when(stateS1.getId()).thenReturn("S1"); + when(stateS1.getIds()).thenReturn(Arrays.asList("S1")); + State stateS2 = mock(State.class); + when(stateS1.getId()).thenReturn("S2"); + when(stateS1.getIds()).thenReturn(Arrays.asList("S2")); + State stateS3 = mock(State.class); + when(stateS1.getId()).thenReturn("S3"); + when(stateS1.getIds()).thenReturn(Arrays.asList("S3")); + + Transition transitionS1S2 = mock(Transition.class); + when(transitionS1S2.getSource()).thenReturn(stateS1); + when(transitionS1S2.getTarget()).thenReturn(stateS2); + when(transitionS1S2.getTrigger()).thenReturn(triggerE1); + when(transitionS1S2.transit(any())).thenReturn(true); + + Transition transitionS1S3 = mock(Transition.class); + when(transitionS1S3.getSource()).thenReturn(stateS1); + when(transitionS1S3.getTarget()).thenReturn(stateS3); + when(transitionS1S3.getTrigger()).thenReturn(triggerTimer); + when(transitionS1S3.transit(any())).thenReturn(true); + + + StateMachine stateMachine = mock(StateMachine.class); + when(stateMachine.getState()).thenReturn(stateS1); + + Collection> transitions = new ArrayList<>(); + transitions.add(transitionS1S2); + + Map, Transition> triggerToTransitionMap = new HashMap<>(); + triggerToTransitionMap.put(triggerE1, transitionS1S2); + triggerToTransitionMap.put(triggerTimer, transitionS1S3); + + List> triggerlessTransitions = new ArrayList<>(); + + Transition initialTransition = mock(Transition.class); + Message initialEvent = null; + + DefaultStateMachineExecutor executor = new DefaultStateMachineExecutor<>( + stateMachine, + stateMachine, + transitions, + triggerToTransitionMap, + triggerlessTransitions, + initialTransition, + initialEvent); + + executor.setTaskExecutor(taskExecutor); + + TestStateMachineExecutorTransit transit = new TestStateMachineExecutorTransit(); + transit.reset(2); + executor.setStateMachineExecutorTransit(transit); + executor.start(); + + triggerTimer.start(); + triggerTimer.arm(); + + assertThat(transit.latch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(transit.transitions.size(), is(2)); + } + + @SuppressWarnings("unchecked") + @Test + public void testDeadlock() throws Exception { + // gh-315 + // nasty, with deadlock you can't use junit timeout + // as then test is run on different thread, thus test doesn't fail. + SyncTaskExecutor taskExecutor = new SyncTaskExecutor(); + ConcurrentTaskScheduler taskScheduler = new ConcurrentTaskScheduler(); + + EventTrigger triggerE1 = new EventTrigger("E1"); + + TimerTrigger triggerTimer = new TimerTrigger<>(1000); + triggerTimer.setTaskScheduler(taskScheduler); + + State stateS1 = mock(State.class); + when(stateS1.getId()).thenReturn("S1"); + when(stateS1.getIds()).thenReturn(Arrays.asList("S1")); + State stateS2 = mock(State.class); + when(stateS1.getId()).thenReturn("S2"); + when(stateS1.getIds()).thenReturn(Arrays.asList("S2")); + State stateS3 = mock(State.class); + when(stateS1.getId()).thenReturn("S3"); + when(stateS1.getIds()).thenReturn(Arrays.asList("S3")); + + Transition transitionS1S2 = mock(Transition.class); + when(transitionS1S2.getSource()).thenReturn(stateS1); + when(transitionS1S2.getTarget()).thenReturn(stateS2); + when(transitionS1S2.getTrigger()).thenReturn(triggerE1); + when(transitionS1S2.transit(any())).thenReturn(true); + + Transition transitionS1S3 = mock(Transition.class); + when(transitionS1S3.getSource()).thenReturn(stateS1); + when(transitionS1S3.getTarget()).thenReturn(stateS3); + when(transitionS1S3.getTrigger()).thenReturn(triggerTimer); + when(transitionS1S3.transit(any())).thenReturn(true); + + + StateMachine stateMachine = mock(StateMachine.class); + when(stateMachine.getState()).thenReturn(stateS1); + + Collection> transitions = new ArrayList<>(); + transitions.add(transitionS1S2); + + Map, Transition> triggerToTransitionMap = new HashMap<>(); + triggerToTransitionMap.put(triggerE1, transitionS1S2); + triggerToTransitionMap.put(triggerTimer, transitionS1S3); + + List> triggerlessTransitions = new ArrayList<>(); + + Transition initialTransition = mock(Transition.class); + Message initialEvent = null; + + DefaultStateMachineExecutor executor = new DefaultStateMachineExecutor<>( + stateMachine, + stateMachine, + transitions, + triggerToTransitionMap, + triggerlessTransitions, + initialTransition, + initialEvent); + + executor.setTaskExecutor(taskExecutor); + + TestStateMachineExecutorTransit transit = new TestStateMachineExecutorTransit(); + transit.reset(2); + executor.setStateMachineExecutorTransit(transit); + executor.start(); + + triggerTimer.start(); + assertThat(transit.latch.await(2, TimeUnit.SECONDS), is(true)); + assertThat(transit.transitions.size(), is(2)); + } + + private static class TestStateMachineExecutorTransit implements StateMachineExecutorTransit { + + ArrayList> transitions = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(1); + + @Override + public void transit(Transition transition, StateContext stateContext, Message message) { + transitions.add(transition); + latch.countDown(); + } + + void reset(int i) { + latch = new CountDownLatch(i); + transitions.clear(); + } + } +}