diff --git a/CHANGELOG.md b/CHANGELOG.md index 733f443c..f6e2e8e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ * Fix the potential NPE issue of `DurableTaskClient#terminate` method ([#104](https://github.com/microsoft/durabletask-java/issues/104)) * Add waitForCompletionOrCreateCheckStatusResponse client API ([#115](https://github.com/microsoft/durabletask-java/pull/115)) * Support long timers by breaking up into smaller timers ([#114](https://github.com/microsoft/durabletask-java/issues/114)) +* Support restartInstance and pass restartPostUri in HttpManagementPayload ([#108](https://github.com/microsoft/durabletask-java/issues/108)) ## v1.0.0 diff --git a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/HttpManagementPayload.java b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/HttpManagementPayload.java index bfd15cf0..4cb45bc8 100644 --- a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/HttpManagementPayload.java +++ b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/HttpManagementPayload.java @@ -12,6 +12,7 @@ public class HttpManagementPayload { private final String id; private final String purgeHistoryDeleteUri; + private final String restartPostUri; private final String sendEventPostUri; private final String statusQueryGetUri; private final String terminatePostUri; @@ -29,6 +30,7 @@ public HttpManagementPayload( String requiredQueryStringParameters) { this.id = instanceId; this.purgeHistoryDeleteUri = instanceStatusURL + "?" + requiredQueryStringParameters; + this.restartPostUri = instanceStatusURL + "/restart?" + requiredQueryStringParameters; this.sendEventPostUri = instanceStatusURL + "/raiseEvent/{eventName}?" + requiredQueryStringParameters; this.statusQueryGetUri = instanceStatusURL + "?" + requiredQueryStringParameters; this.terminatePostUri = instanceStatusURL + "/terminate?reason={text}&" + requiredQueryStringParameters; @@ -78,4 +80,14 @@ public String getTerminatePostUri() { public String getPurgeHistoryDeleteUri() { return this.purgeHistoryDeleteUri; } + + /** + * Gets the HTTP POST instance restart endpoint. + * + * @return The HTTP URL for posting instance restart commands. + */ + public String getRestartPostUri() { + return restartPostUri; + } + } diff --git a/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java b/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java index c914c229..14a8135e 100644 --- a/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java +++ b/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java @@ -282,6 +282,16 @@ public abstract OrchestrationMetadata waitForInstanceCompletion( */ public abstract PurgeResult purgeInstances(PurgeInstanceCriteria purgeInstanceCriteria) throws TimeoutException; + /** + * Restarts an existing orchestration instance with the original input. + * @param instanceId the ID of the previously run orchestration instance to restart. + * @param restartWithNewInstanceId true to restart the orchestration instance with a new instance ID + * false to restart the orchestration instance with same instance ID + * @return the ID of the scheduled orchestration instance, which is either instanceId or randomly + * generated depending on the value of restartWithNewInstanceId + */ + public abstract String restartInstance(String instanceId, boolean restartWithNewInstanceId); + // /** // * Suspends a running orchestration instance. // * @param instanceId the ID of the orchestration instance to suspend diff --git a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java index 56d10804..35798989 100644 --- a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java +++ b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java @@ -303,6 +303,24 @@ public PurgeResult purgeInstances(PurgeInstanceCriteria purgeInstanceCriteria) t // this.sidecarClient.resumeInstance(resumeRequestBuilder.build()); // } + @Override + public String restartInstance(String instanceId, boolean restartWithNewInstanceId) { + OrchestrationMetadata metadata = this.getInstanceMetadata(instanceId, true); + if (!metadata.isInstanceFound()) { + throw new IllegalArgumentException(new StringBuilder() + .append("An orchestration with instanceId ") + .append(instanceId) + .append(" was not found.").toString()); + } + + if (restartWithNewInstanceId) { + return this.scheduleNewOrchestrationInstance(metadata.getName(), this.dataConverter.deserialize(metadata.getSerializedInput(), Object.class)); + } + else { + return this.scheduleNewOrchestrationInstance(metadata.getName(), this.dataConverter.deserialize(metadata.getSerializedInput(), Object.class), metadata.getInstanceId()); + } + } + private PurgeResult toPurgeResult(PurgeInstancesResponse response){ return new PurgeResult(response.getDeletedInstanceCount()); } diff --git a/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java b/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java index 5d6a47f5..6edef35e 100644 --- a/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java +++ b/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java @@ -408,6 +408,55 @@ void termination() throws TimeoutException { } } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void restartOrchestrationWithNewInstanceId(boolean restartWithNewInstanceId) throws TimeoutException { + final String orchestratorName = "restart"; + final Duration delay = Duration.ofSeconds(3); + + DurableTaskGrpcWorker worker = this.createWorkerBuilder() + .addOrchestrator(orchestratorName, ctx -> ctx.createTimer(delay).await()) + .buildAndStart(); + + DurableTaskClient client = new DurableTaskGrpcClientBuilder().build(); + try (worker; client) { + String instanceId = client.scheduleNewOrchestrationInstance(orchestratorName, "RestartTest"); + client.waitForInstanceCompletion(instanceId, defaultTimeout, true); + String newInstanceId = client.restartInstance(instanceId, restartWithNewInstanceId); + OrchestrationMetadata instance = client.waitForInstanceCompletion(newInstanceId, defaultTimeout, true); + + if (restartWithNewInstanceId) { + assertNotEquals(instanceId, newInstanceId); + } else { + assertEquals(instanceId, newInstanceId); + } + assertEquals(OrchestrationRuntimeStatus.COMPLETED, instance.getRuntimeStatus()); + assertEquals("\"RestartTest\"", instance.getSerializedInput()); + } + } + + @Test + void restartOrchestrationThrowsException() { + final String orchestratorName = "restart"; + final Duration delay = Duration.ofSeconds(3); + final String nonExistentId = "123"; + + DurableTaskGrpcWorker worker = this.createWorkerBuilder() + .addOrchestrator(orchestratorName, ctx -> ctx.createTimer(delay).await()) + .buildAndStart(); + + DurableTaskClient client = new DurableTaskGrpcClientBuilder().build(); + try (worker; client) { + client.scheduleNewOrchestrationInstance(orchestratorName, "RestartTest"); + + assertThrows( + IllegalArgumentException.class, + () -> client.restartInstance(nonExistentId, true) + ); + } + + } + // @Test // void suspendResumeOrchestration() throws TimeoutException, InterruptedException { // final String orchestratorName = "suspend"; diff --git a/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java b/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java index a174b39b..fa5f64df 100644 --- a/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java +++ b/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java @@ -5,6 +5,8 @@ import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import static io.restassured.RestAssured.get; import static io.restassured.RestAssured.post; @@ -26,14 +28,7 @@ public void basicChain() throws InterruptedException { Response response = post(startOrchestrationPath); JsonPath jsonPath = response.jsonPath(); String statusQueryGetUri = jsonPath.get("statusQueryGetUri"); - String runTimeStatus = null; - for (int i = 0; i < 15; i++) { - Response statusResponse = get(statusQueryGetUri); - runTimeStatus = statusResponse.jsonPath().get("runtimeStatus"); - if (!"Completed".equals(runTimeStatus)) { - Thread.sleep(1000); - } else break; - } + String runTimeStatus = waitForCompletion(statusQueryGetUri); assertEquals("Completed", runTimeStatus); } @@ -59,4 +54,44 @@ public void continueAsNew() throws InterruptedException { runTimeStatus = statusResponse.jsonPath().get("runtimeStatus"); assertEquals("Terminated", runTimeStatus); } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void restart(boolean restartWithNewInstanceId) throws InterruptedException { + String startOrchestrationPath = "/api/StartOrchestration"; + Response response = post(startOrchestrationPath); + JsonPath jsonPath = response.jsonPath(); + String statusQueryGetUri = jsonPath.get("statusQueryGetUri"); + String runTimeStatus = waitForCompletion(statusQueryGetUri); + assertEquals("Completed", runTimeStatus); + Response statusResponse = get(statusQueryGetUri); + String instanceId = statusResponse.jsonPath().get("instanceId"); + + String restartPostUri = jsonPath.get("restartPostUri") + "&restartWithNewInstanceId=" + restartWithNewInstanceId; + Response restartResponse = post(restartPostUri); + JsonPath restartJsonPath = restartResponse.jsonPath(); + String restartStatusQueryGetUri = restartJsonPath.get("statusQueryGetUri"); + String restartRuntimeStatus = waitForCompletion(restartStatusQueryGetUri); + assertEquals("Completed", restartRuntimeStatus); + Response restartStatusResponse = get(restartStatusQueryGetUri); + String newInstanceId = restartStatusResponse.jsonPath().get("instanceId"); + if (restartWithNewInstanceId) { + assertNotEquals(instanceId, newInstanceId); + } else { + assertEquals(instanceId, newInstanceId); + } + } + + private String waitForCompletion(String statusQueryGetUri) throws InterruptedException { + String runTimeStatus = null; + for (int i = 0; i < 15; i++) { + Response statusResponse = get(statusQueryGetUri); + runTimeStatus = statusResponse.jsonPath().get("runtimeStatus"); + if (!"Completed".equals(runTimeStatus)) { + Thread.sleep(1000); + } else break; + } + return runTimeStatus; + } + }