diff --git a/spring-batch-infrastructure/src/main/java/org/springframework/batch/repeat/support/TaskExecutorRepeatTemplate.java b/spring-batch-infrastructure/src/main/java/org/springframework/batch/repeat/support/TaskExecutorRepeatTemplate.java index 44dba5b449..b664b64a66 100644 --- a/spring-batch-infrastructure/src/main/java/org/springframework/batch/repeat/support/TaskExecutorRepeatTemplate.java +++ b/spring-batch-infrastructure/src/main/java/org/springframework/batch/repeat/support/TaskExecutorRepeatTemplate.java @@ -16,6 +16,8 @@ package org.springframework.batch.repeat.support; +import java.util.concurrent.CountDownLatch; + import org.springframework.batch.repeat.RepeatCallback; import org.springframework.batch.repeat.RepeatContext; import org.springframework.batch.repeat.RepeatException; @@ -59,6 +61,12 @@ public class TaskExecutorRepeatTemplate extends RepeatTemplate { private TaskExecutor taskExecutor = new SyncTaskExecutor(); + /** + * A latch to ensure to manage the first chunk by the the first thread. This is + * specifically required to manage data with record separators like JSON. + */ + private final CountDownLatch latch = new CountDownLatch(1); + /** * Public setter for the throttle limit. The throttle limit is the largest number of * concurrent tasks that can be executing at one time - if a new task arrives and the @@ -110,7 +118,7 @@ protected RepeatStatus getNextResult(RepeatContext context, RepeatCallback callb * Wrap the callback in a runnable that will add its result to the queue when * it is ready. */ - runnable = new ExecutingRunnable(callback, context, queue); + runnable = new ExecutingRunnable(callback, context, queue, latch); /* * Tell the runnable that it can expect a result. This could have been @@ -130,6 +138,13 @@ protected RepeatStatus getNextResult(RepeatContext context, RepeatCallback callb */ update(context); + /* + * Wait for the first chunk to be managed before to create other threads. This + * will ensure to correctly write first data chunk with record separators like + * JSON. + */ + latch.await(); + /* * Keep going until we get a result that is finished, or early termination... */ @@ -216,14 +231,17 @@ private class ExecutingRunnable implements Runnable, ResultHolder { private volatile Throwable error; - public ExecutingRunnable(RepeatCallback callback, RepeatContext context, ResultQueue queue) { + private CountDownLatch latch; + + public ExecutingRunnable(RepeatCallback callback, RepeatContext context, ResultQueue queue, + CountDownLatch latch) { super(); this.callback = callback; this.context = context; this.queue = queue; - + this.latch = latch; } /** @@ -272,6 +290,11 @@ public void run() { queue.put(this); + /* + * If this is the first chunk, then release the latch so that other + * threads can be created. + */ + this.latch.countDown(); } } diff --git a/spring-batch-infrastructure/src/test/java/org/springframework/batch/repeat/support/AbstractTradeBatchTests.java b/spring-batch-infrastructure/src/test/java/org/springframework/batch/repeat/support/AbstractTradeBatchTests.java index a95e9fbed7..3d19917966 100644 --- a/spring-batch-infrastructure/src/test/java/org/springframework/batch/repeat/support/AbstractTradeBatchTests.java +++ b/spring-batch-infrastructure/src/test/java/org/springframework/batch/repeat/support/AbstractTradeBatchTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2006-2023 the original author or authors. + * Copyright 2006-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.batch.repeat.support; +import java.util.ArrayList; + import org.junit.jupiter.api.BeforeEach; import org.springframework.batch.item.Chunk; @@ -42,12 +44,16 @@ abstract class AbstractTradeBatchTests { Resource resource = new ClassPathResource("trades.csv", getClass()); - protected TradeWriter processor = new TradeWriter(); + protected TradeWriter processor; protected TradeItemReader provider; + protected ArrayList output; + @BeforeEach void setUp() throws Exception { + output = new ArrayList<>(); + processor = new TradeWriter(output); provider = new TradeItemReader(resource); provider.open(new ExecutionContext()); } @@ -79,10 +85,17 @@ protected static class TradeWriter implements ItemWriter { int count = 0; + private ArrayList out; + + public TradeWriter(ArrayList out) { + this.out = out; + } + // This has to be synchronized because we are going to test the state // (count) at the end of a concurrent batch run. @Override public synchronized void write(Chunk data) { + out.addAll(data.getItems()); count++; } diff --git a/spring-batch-infrastructure/src/test/java/org/springframework/batch/repeat/support/TaskExecutorRepeatTemplateFirstChunkTests.java b/spring-batch-infrastructure/src/test/java/org/springframework/batch/repeat/support/TaskExecutorRepeatTemplateFirstChunkTests.java new file mode 100644 index 0000000000..d369689971 --- /dev/null +++ b/spring-batch-infrastructure/src/test/java/org/springframework/batch/repeat/support/TaskExecutorRepeatTemplateFirstChunkTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.batch.repeat.support; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; + +import org.springframework.batch.repeat.policy.SimpleCompletionPolicy; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; + +/** + * Tests for concurrent behaviour in repeat template, dedicated to the first chunk, that + * must be managed first when output format has separator between items, like JSON. + * + * @author Gerald Lelarge + * + */ +class TaskExecutorRepeatTemplateFirstChunkTests extends AbstractTradeBatchTests { + + private TaskExecutorRepeatTemplate template; + + private int chunkSize = 5; + + private final ThreadPoolTaskExecutor threadPool = new ThreadPoolTaskExecutor(); + + @BeforeEach + void setUp() throws Exception { + + super.setUp(); + + threadPool.setMaxPoolSize(10); + threadPool.setCorePoolSize(10); + threadPool.setQueueCapacity(0); + threadPool.afterPropertiesSet(); + + template = new TaskExecutorRepeatTemplate(); + template.setTaskExecutor(threadPool); + // Limit the number of threads to 2 + template.setThrottleLimit(2); + // Limit the number of items to read to be able to test the second item from the + // output. If the chunkSize is greater than 2, the test could fail. + template.setCompletionPolicy(new SimpleCompletionPolicy(chunkSize)); + } + + @AfterEach + void tearDown() { + threadPool.destroy(); + } + + /** + * Test method for {@link TaskExecutorRepeatTemplate#iterate(RepeatCallback)}. Repeat + * the tests 20 times to increase the probability of detecting a concurrency. + */ + @Test + @RepeatedTest(value = 20) + void testExecute() { + + // given + template.iterate(new ItemReaderRepeatCallback<>(provider, processor)); + + // then + // The first element is the first item of the input trades.csv. + assertEquals("UK21341EAH45", output.get(0).getIsin()); + // The others can have different orders. + for (int i = 1; i < output.size(); i++) { + assertNotEquals("UK21341EAH45", output.get(i).getIsin()); + } + assertEquals(chunkSize, processor.count); + } + +}