diff --git a/CHANGELOG.md b/CHANGELOG.md
index 733f443c..f6e2e8e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@
* Fix the potential NPE issue of `DurableTaskClient#terminate` method ([#104](https://github.com/microsoft/durabletask-java/issues/104))
* Add waitForCompletionOrCreateCheckStatusResponse client API ([#115](https://github.com/microsoft/durabletask-java/pull/115))
* Support long timers by breaking up into smaller timers ([#114](https://github.com/microsoft/durabletask-java/issues/114))
+* Support restartInstance and pass restartPostUri in HttpManagementPayload ([#108](https://github.com/microsoft/durabletask-java/issues/108))
## v1.0.0
diff --git a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/HttpManagementPayload.java b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/HttpManagementPayload.java
index bfd15cf0..4cb45bc8 100644
--- a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/HttpManagementPayload.java
+++ b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/HttpManagementPayload.java
@@ -12,6 +12,7 @@
public class HttpManagementPayload {
private final String id;
private final String purgeHistoryDeleteUri;
+ private final String restartPostUri;
private final String sendEventPostUri;
private final String statusQueryGetUri;
private final String terminatePostUri;
@@ -29,6 +30,7 @@ public HttpManagementPayload(
String requiredQueryStringParameters) {
this.id = instanceId;
this.purgeHistoryDeleteUri = instanceStatusURL + "?" + requiredQueryStringParameters;
+ this.restartPostUri = instanceStatusURL + "/restart?" + requiredQueryStringParameters;
this.sendEventPostUri = instanceStatusURL + "/raiseEvent/{eventName}?" + requiredQueryStringParameters;
this.statusQueryGetUri = instanceStatusURL + "?" + requiredQueryStringParameters;
this.terminatePostUri = instanceStatusURL + "/terminate?reason={text}&" + requiredQueryStringParameters;
@@ -78,4 +80,14 @@ public String getTerminatePostUri() {
public String getPurgeHistoryDeleteUri() {
return this.purgeHistoryDeleteUri;
}
+
+ /**
+ * Gets the HTTP POST instance restart endpoint.
+ *
+ * @return The HTTP URL for posting instance restart commands.
+ */
+ public String getRestartPostUri() {
+ return restartPostUri;
+ }
+
}
diff --git a/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java b/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java
index c914c229..14a8135e 100644
--- a/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java
+++ b/client/src/main/java/com/microsoft/durabletask/DurableTaskClient.java
@@ -282,6 +282,16 @@ public abstract OrchestrationMetadata waitForInstanceCompletion(
*/
public abstract PurgeResult purgeInstances(PurgeInstanceCriteria purgeInstanceCriteria) throws TimeoutException;
+ /**
+ * Restarts an existing orchestration instance with the original input.
+ * @param instanceId the ID of the previously run orchestration instance to restart.
+ * @param restartWithNewInstanceId true
to restart the orchestration instance with a new instance ID
+ * false
to restart the orchestration instance with same instance ID
+ * @return the ID of the scheduled orchestration instance, which is either instanceId
or randomly
+ * generated depending on the value of restartWithNewInstanceId
+ */
+ public abstract String restartInstance(String instanceId, boolean restartWithNewInstanceId);
+
// /**
// * Suspends a running orchestration instance.
// * @param instanceId the ID of the orchestration instance to suspend
diff --git a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java
index 56d10804..35798989 100644
--- a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java
+++ b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java
@@ -303,6 +303,24 @@ public PurgeResult purgeInstances(PurgeInstanceCriteria purgeInstanceCriteria) t
// this.sidecarClient.resumeInstance(resumeRequestBuilder.build());
// }
+ @Override
+ public String restartInstance(String instanceId, boolean restartWithNewInstanceId) {
+ OrchestrationMetadata metadata = this.getInstanceMetadata(instanceId, true);
+ if (!metadata.isInstanceFound()) {
+ throw new IllegalArgumentException(new StringBuilder()
+ .append("An orchestration with instanceId ")
+ .append(instanceId)
+ .append(" was not found.").toString());
+ }
+
+ if (restartWithNewInstanceId) {
+ return this.scheduleNewOrchestrationInstance(metadata.getName(), this.dataConverter.deserialize(metadata.getSerializedInput(), Object.class));
+ }
+ else {
+ return this.scheduleNewOrchestrationInstance(metadata.getName(), this.dataConverter.deserialize(metadata.getSerializedInput(), Object.class), metadata.getInstanceId());
+ }
+ }
+
private PurgeResult toPurgeResult(PurgeInstancesResponse response){
return new PurgeResult(response.getDeletedInstanceCount());
}
diff --git a/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java b/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java
index 5d6a47f5..6edef35e 100644
--- a/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java
+++ b/client/src/test/java/com/microsoft/durabletask/IntegrationTests.java
@@ -408,6 +408,55 @@ void termination() throws TimeoutException {
}
}
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void restartOrchestrationWithNewInstanceId(boolean restartWithNewInstanceId) throws TimeoutException {
+ final String orchestratorName = "restart";
+ final Duration delay = Duration.ofSeconds(3);
+
+ DurableTaskGrpcWorker worker = this.createWorkerBuilder()
+ .addOrchestrator(orchestratorName, ctx -> ctx.createTimer(delay).await())
+ .buildAndStart();
+
+ DurableTaskClient client = new DurableTaskGrpcClientBuilder().build();
+ try (worker; client) {
+ String instanceId = client.scheduleNewOrchestrationInstance(orchestratorName, "RestartTest");
+ client.waitForInstanceCompletion(instanceId, defaultTimeout, true);
+ String newInstanceId = client.restartInstance(instanceId, restartWithNewInstanceId);
+ OrchestrationMetadata instance = client.waitForInstanceCompletion(newInstanceId, defaultTimeout, true);
+
+ if (restartWithNewInstanceId) {
+ assertNotEquals(instanceId, newInstanceId);
+ } else {
+ assertEquals(instanceId, newInstanceId);
+ }
+ assertEquals(OrchestrationRuntimeStatus.COMPLETED, instance.getRuntimeStatus());
+ assertEquals("\"RestartTest\"", instance.getSerializedInput());
+ }
+ }
+
+ @Test
+ void restartOrchestrationThrowsException() {
+ final String orchestratorName = "restart";
+ final Duration delay = Duration.ofSeconds(3);
+ final String nonExistentId = "123";
+
+ DurableTaskGrpcWorker worker = this.createWorkerBuilder()
+ .addOrchestrator(orchestratorName, ctx -> ctx.createTimer(delay).await())
+ .buildAndStart();
+
+ DurableTaskClient client = new DurableTaskGrpcClientBuilder().build();
+ try (worker; client) {
+ client.scheduleNewOrchestrationInstance(orchestratorName, "RestartTest");
+
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> client.restartInstance(nonExistentId, true)
+ );
+ }
+
+ }
+
// @Test
// void suspendResumeOrchestration() throws TimeoutException, InterruptedException {
// final String orchestratorName = "suspend";
diff --git a/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java b/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java
index a174b39b..fa5f64df 100644
--- a/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java
+++ b/samples-azure-functions/src/test/java/com/functions/EndToEndTests.java
@@ -5,6 +5,8 @@
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
import static io.restassured.RestAssured.get;
import static io.restassured.RestAssured.post;
@@ -26,14 +28,7 @@ public void basicChain() throws InterruptedException {
Response response = post(startOrchestrationPath);
JsonPath jsonPath = response.jsonPath();
String statusQueryGetUri = jsonPath.get("statusQueryGetUri");
- String runTimeStatus = null;
- for (int i = 0; i < 15; i++) {
- Response statusResponse = get(statusQueryGetUri);
- runTimeStatus = statusResponse.jsonPath().get("runtimeStatus");
- if (!"Completed".equals(runTimeStatus)) {
- Thread.sleep(1000);
- } else break;
- }
+ String runTimeStatus = waitForCompletion(statusQueryGetUri);
assertEquals("Completed", runTimeStatus);
}
@@ -59,4 +54,44 @@ public void continueAsNew() throws InterruptedException {
runTimeStatus = statusResponse.jsonPath().get("runtimeStatus");
assertEquals("Terminated", runTimeStatus);
}
+
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ public void restart(boolean restartWithNewInstanceId) throws InterruptedException {
+ String startOrchestrationPath = "/api/StartOrchestration";
+ Response response = post(startOrchestrationPath);
+ JsonPath jsonPath = response.jsonPath();
+ String statusQueryGetUri = jsonPath.get("statusQueryGetUri");
+ String runTimeStatus = waitForCompletion(statusQueryGetUri);
+ assertEquals("Completed", runTimeStatus);
+ Response statusResponse = get(statusQueryGetUri);
+ String instanceId = statusResponse.jsonPath().get("instanceId");
+
+ String restartPostUri = jsonPath.get("restartPostUri") + "&restartWithNewInstanceId=" + restartWithNewInstanceId;
+ Response restartResponse = post(restartPostUri);
+ JsonPath restartJsonPath = restartResponse.jsonPath();
+ String restartStatusQueryGetUri = restartJsonPath.get("statusQueryGetUri");
+ String restartRuntimeStatus = waitForCompletion(restartStatusQueryGetUri);
+ assertEquals("Completed", restartRuntimeStatus);
+ Response restartStatusResponse = get(restartStatusQueryGetUri);
+ String newInstanceId = restartStatusResponse.jsonPath().get("instanceId");
+ if (restartWithNewInstanceId) {
+ assertNotEquals(instanceId, newInstanceId);
+ } else {
+ assertEquals(instanceId, newInstanceId);
+ }
+ }
+
+ private String waitForCompletion(String statusQueryGetUri) throws InterruptedException {
+ String runTimeStatus = null;
+ for (int i = 0; i < 15; i++) {
+ Response statusResponse = get(statusQueryGetUri);
+ runTimeStatus = statusResponse.jsonPath().get("runtimeStatus");
+ if (!"Completed".equals(runTimeStatus)) {
+ Thread.sleep(1000);
+ } else break;
+ }
+ return runTimeStatus;
+ }
+
}