From 36b50934001f4f2e43bed6953330d8b7b3a67817 Mon Sep 17 00:00:00 2001 From: "mend-for-github-com[bot]" <50673670+mend-for-github-com[bot]@users.noreply.github.com> Date: Thu, 16 Jan 2025 12:00:07 -0800 Subject: [PATCH 1/3] fix(deps): update swaggercoreversion to v2.2.28 (#1005) Signed-off-by: mend-for-github-com[bot] Co-authored-by: mend-for-github-com[bot] <50673670+mend-for-github-com[bot]@users.noreply.github.com> --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index 68f904e3..144683cc 100644 --- a/build.gradle +++ b/build.gradle @@ -36,7 +36,7 @@ buildscript { isSameMajorVersion = opensearch_version.split("\\.")[0] == bwcVersionShort.split("\\.")[0] swaggerVersion = "2.1.24" jacksonVersion = "2.18.1" - swaggerCoreVersion = "2.2.27" + swaggerCoreVersion = "2.2.28" } From 06b78310a1d786e1c69f16751177ecb843abf6fe Mon Sep 17 00:00:00 2001 From: "mend-for-github-com[bot]" <50673670+mend-for-github-com[bot]@users.noreply.github.com> Date: Thu, 16 Jan 2025 12:02:44 -0800 Subject: [PATCH 2/3] chore(deps): update dependency gradle to v8.12 (#984) Signed-off-by: mend-for-github-com[bot] Co-authored-by: mend-for-github-com[bot] <50673670+mend-for-github-com[bot]@users.noreply.github.com> Co-authored-by: Amit Galitzky --- gradle/wrapper/gradle-wrapper.properties | 4 ++-- gradlew | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index eb1a55be..e1b837a1 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=f397b287023acdba1e9f6fc5ea72d22dd63669d59ed4a289a29b1a76eee151c6 -distributionUrl=https\://services.gradle.org/distributions/gradle-8.11.1-bin.zip +distributionSha256Sum=7a00d51fb93147819aab76024feece20b6b84e420694101f276be952e08bef03 +distributionUrl=https\://services.gradle.org/distributions/gradle-8.12-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/gradlew b/gradlew index f5feea6d..f3b75f3b 100755 --- a/gradlew +++ b/gradlew @@ -86,8 +86,7 @@ done # shellcheck disable=SC2034 APP_BASE_NAME=${0##*/} # Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) -APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s -' "$PWD" ) || exit +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s\n' "$PWD" ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum From 33579a35a9d7a72cc0a2d44561b98bd64a79dd04 Mon Sep 17 00:00:00 2001 From: Junwei Dai <59641585+junweid62@users.noreply.github.com> Date: Thu, 16 Jan 2025 15:36:14 -0800 Subject: [PATCH 3/3] Add synchronous execution option to workflow provisioning (#990) * Add synchronous execution option to workflow provisioning Signed-off-by: Junwei Dai * code refactor Signed-off-by: Junwei Dai * add change log Signed-off-by: Junwei Dai * refactor code based on comment Signed-off-by: Junwei Dai * fix spotless check Signed-off-by: Junwei Dai * Limit workflow timeout to a range of 1 to 300 seconds Signed-off-by: Junwei Dai * Limit workflow timeout to a range of 1 to 300 seconds Signed-off-by: Junwei Dai # Conflicts: # src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java * Limit workflow timeout to non-negative Signed-off-by: Junwei Dai * Add synchronous execution to reprovision Signed-off-by: Junwei Dai * remove unsued common value Signed-off-by: Junwei Dai * add reprovision sync execution Signed-off-by: Junwei Dai # Conflicts: # src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java * fix test for WorkflowTimeoutUtilityTests Signed-off-by: Junwei Dai * fix test name for WorkflowTimeoutUtilityTests Signed-off-by: Junwei Dai * Add comments to explain AtomicBoolean usage in WorkflowTimeoutUtility, update error message Signed-off-by: Junwei Dai * fix spotless check Signed-off-by: Junwei Dai * addressed some comments Signed-off-by: Junwei Dai --------- Signed-off-by: Junwei Dai Co-authored-by: Junwei Dai --- CHANGELOG.md | 2 + .../flowframework/common/CommonValue.java | 4 + .../rest/RestCreateWorkflowAction.java | 16 +- .../rest/RestProvisionWorkflowAction.java | 5 +- .../CreateWorkflowTransportAction.java | 36 ++- .../ProvisionWorkflowTransportAction.java | 88 ++++++- .../transport/ReprovisionWorkflowRequest.java | 30 ++- .../ReprovisionWorkflowTransportAction.java | 81 ++++++- .../transport/WorkflowRequest.java | 76 +++++- .../transport/WorkflowResponse.java | 46 +++- .../util/WorkflowTimeoutUtility.java | 202 +++++++++++++++ .../rest/RestCreateWorkflowActionTests.java | 38 +++ .../RestProvisionWorkflowActionTests.java | 22 ++ .../CreateWorkflowTransportActionTests.java | 229 +++++++++++++++++- .../ReprovisionWorkflowRequestTests.java | 3 +- ...provisionWorkflowTransportActionTests.java | 11 +- .../WorkflowRequestResponseTests.java | 40 ++- .../util/WorkflowTimeoutUtilityTests.java | 136 +++++++++++ 18 files changed, 1026 insertions(+), 39 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java create mode 100644 src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index c8f99f0b..cd347365 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.18...2.x) ### Features +- Add synchronous execution option to workflow provisioning ([#990](https://github.com/opensearch-project/flow-framework/pull/990)) + ### Enhancements ### Bug Fixes - Remove useCase and defaultParams field in WorkflowRequest ([#758](https://github.com/opensearch-project/flow-framework/pull/758)) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 9c88788b..0a4af075 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -74,6 +74,8 @@ private CommonValue() {} public static final String PROVISION_WORKFLOW = "provision"; /** The param name for update workflow field in create API */ public static final String UPDATE_WORKFLOW_FIELDS = "update_fields"; + /** The param name for specifying the timeout duration in seconds to wait for workflow completion */ + public static final String WAIT_FOR_COMPLETION_TIMEOUT = "wait_for_completion_timeout"; /** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */ public static final String WORKFLOW_STEP = "workflow_step"; /** The param name for default use case, used by the create workflow API */ @@ -186,6 +188,8 @@ private CommonValue() {} public static final String SOURCE_INDEX = "source_index"; /** The destination index field for reindex */ public static final String DESTINATION_INDEX = "destination_index"; + /** Provision Timeout field */ + public static final String PROVISION_TIMEOUT_FIELD = "provision.timeout"; /* * Constants associated with resource provisioning / state */ diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 4abedc36..b106b05f 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -43,6 +44,7 @@ import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.VALIDATION; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -88,6 +90,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli boolean reprovision = request.paramAsBoolean(REPROVISION_WORKFLOW, false); boolean updateFields = request.paramAsBoolean(UPDATE_WORKFLOW_FIELDS, false); String useCase = request.param(USE_CASE); + TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, TimeValue.MINUS_ONE); // If provisioning, consume all other params and pass to provision transport action Map params = provision @@ -145,6 +148,15 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); return processError(ffe, params, request); } + // Ensure wait_for_completion is not set unless reprovision or provision is true + if (waitForCompletionTimeout != TimeValue.MINUS_ONE && !(reprovision || provision)) { + FlowFrameworkException ffe = new FlowFrameworkException( + "Request parameters 'wait_for_completion_timeout' are not allowed unless the 'provision' or 'reprovision' parameter is set to true.", + RestStatus.BAD_REQUEST + ); + return processError(ffe, params, request); + } + try { Template template; Map useCaseDefaultsMap = Collections.emptyMap(); @@ -219,7 +231,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli if (updateFields) { params = Map.of(UPDATE_WORKFLOW_FIELDS, "true"); } - + if (waitForCompletionTimeout != TimeValue.MINUS_ONE) { + params = Map.of(WAIT_FOR_COMPLETION_TIMEOUT, waitForCompletionTimeout.toString()); + } WorkflowRequest workflowRequest = new WorkflowRequest( workflowId, template, diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 6ae56905..e197312e 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -33,6 +34,7 @@ import java.util.stream.Collectors; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -73,6 +75,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); + TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, TimeValue.MINUS_ONE); try { Map params = parseParamsAndContent(request); if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { @@ -86,7 +89,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } // Create request and provision - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params, waitForCompletionTimeout); return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 813613a3..ff1df88b 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -53,6 +53,7 @@ import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.util.ParseUtils.checkFilterByBackendRoles; import static org.opensearch.flowframework.util.ParseUtils.getUserContext; @@ -214,6 +215,16 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { - listener.onResponse(new WorkflowResponse(provisionResponse.getWorkflowId())); + listener.onResponse( + (workflowRequest.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) + ? new WorkflowResponse(provisionResponse.getWorkflowId()) + : new WorkflowResponse( + provisionResponse.getWorkflowId(), + provisionResponse.getWorkflowState() + ) + ); }, exception -> { String errorMessage = "Provisioning failed."; logger.error(errorMessage, exception); @@ -346,19 +365,26 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { - listener.onResponse(new WorkflowResponse(reprovisionResponse.getWorkflowId())); + listener.onResponse( + reprovisionRequest.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE + ? new WorkflowResponse(reprovisionResponse.getWorkflowId()) + : new WorkflowResponse( + reprovisionResponse.getWorkflowId(), + reprovisionResponse.getWorkflowState() + ) + ); }, exception -> { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Reprovisioning failed for workflow {}", diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 45f37416..bfe9aee0 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -20,6 +20,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -32,6 +33,7 @@ import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.util.WorkflowTimeoutUtility; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.plugins.PluginsService; @@ -45,6 +47,8 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; @@ -210,14 +214,27 @@ private void executeProvisionRequest( ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING); - executeWorkflowAsync(workflowId, provisionProcessSequence, listener); + if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { + executeWorkflowAsync(workflowId, provisionProcessSequence, listener); + } else { + executeWorkflowSync( + workflowId, + provisionProcessSequence, + listener, + request.getWaitForCompletionTimeout().getMillis() + ); + } // update last provisioned field in template Template newTemplate = Template.builder(template).lastProvisionedTime(Instant.now()).build(); flowFrameworkIndicesHandler.updateTemplateInGlobalContext( request.getWorkflowId(), newTemplate, ActionListener.wrap(templateResponse -> { - listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + } else { + logger.info("Waiting for workflow completion"); + } }, exception -> { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Failed to update use case template {}", @@ -275,18 +292,64 @@ private void executeProvisionRequest( */ private void executeWorkflowAsync(String workflowId, List workflowSequence, ActionListener listener) { try { - threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, workflowId); }); + threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL) + .execute(() -> { executeWorkflow(workflowSequence, workflowId, listener, false); }); } catch (Exception exception) { listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(exception))); } } + /** + * Retrieves a thread from the provision thread pool to execute a workflow with a timeout mechanism. + * If the execution exceeds the specified timeout, it will return the current status of the workflow. + * + * @param workflowId The id of the workflow + * @param workflowSequence The sorted workflow to execute + * @param listener ActionListener for any failures or responses + * @param timeout The timeout duration in milliseconds + */ + private void executeWorkflowSync( + String workflowId, + List workflowSequence, + ActionListener listener, + long timeout + ) { + AtomicBoolean isResponseSent = new AtomicBoolean(false); + + CompletableFuture.runAsync(() -> { + try { + executeWorkflow(workflowSequence, workflowId, new ActionListener<>() { + @Override + public void onResponse(WorkflowResponse workflowResponse) { + WorkflowTimeoutUtility.handleResponse(workflowId, workflowResponse, isResponseSent, listener); + } + + @Override + public void onFailure(Exception e) { + WorkflowTimeoutUtility.handleFailure(workflowId, e, isResponseSent, listener); + } + }, true); + } catch (Exception ex) { + WorkflowTimeoutUtility.handleFailure(workflowId, ex, isResponseSent, listener); + } + }, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL)); + + WorkflowTimeoutUtility.scheduleTimeoutHandler(client, threadPool, workflowId, listener, timeout, isResponseSent); + } + /** * Executes the given workflow sequence * @param workflowSequence The topologically sorted workflow to execute * @param workflowId The workflowId associated with the workflow that is executing + * @param listener The ActionListener to handle the workflow response or failure + * @param isSyncExecution Flag indicating whether the workflow should be executed synchronously (true) or asynchronously (false) */ - private void executeWorkflow(List workflowSequence, String workflowId) { + private void executeWorkflow( + List workflowSequence, + String workflowId, + ActionListener listener, + boolean isSyncExecution + ) { String currentStepId = ""; try { Map> workflowFutureMap = new LinkedHashMap<>(); @@ -324,6 +387,23 @@ private void executeWorkflow(List workflowSequence, String workflow ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + if (isSyncExecution) { + client.execute( + GetWorkflowStateAction.INSTANCE, + new GetWorkflowStateRequest(workflowId, false), + ActionListener.wrap(response -> { + listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())); + }, exception -> { + String errorMessage = "Failed to get workflow state."; + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }) + ); + } }, exception -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exception); }) ); } catch (Exception ex) { diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequest.java index f6cde633..e8760fdf 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequest.java @@ -8,8 +8,10 @@ */ package org.opensearch.flowframework.transport; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.flowframework.model.Template; @@ -34,16 +36,28 @@ public class ReprovisionWorkflowRequest extends ActionRequest { */ private Template updatedTemplate; + /** + * The timeout value for waiting for completion + */ + private TimeValue waitForCompletionTimeout; + /** * Instantiates a new ReprovisionWorkflowRequest * @param workflowId the workflow ID * @param originalTemplate the original Template * @param updatedTemplate the updated Template + * @param waitForCompletionTimeout the maximum duration to wait for the workflow execution to complete. */ - public ReprovisionWorkflowRequest(String workflowId, Template originalTemplate, Template updatedTemplate) { + public ReprovisionWorkflowRequest( + String workflowId, + Template originalTemplate, + Template updatedTemplate, + TimeValue waitForCompletionTimeout + ) { this.workflowId = workflowId; this.originalTemplate = originalTemplate; this.updatedTemplate = updatedTemplate; + this.waitForCompletionTimeout = waitForCompletionTimeout; } /** @@ -56,6 +70,10 @@ public ReprovisionWorkflowRequest(StreamInput in) throws IOException { this.workflowId = in.readString(); this.originalTemplate = Template.parse(in.readString()); this.updatedTemplate = Template.parse(in.readString()); + if (in.getVersion().onOrAfter(Version.V_2_19_0)) { + this.waitForCompletionTimeout = in.readTimeValue(); + } + } @Override @@ -64,6 +82,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(workflowId); out.writeString(originalTemplate.toJson()); out.writeString(updatedTemplate.toJson()); + if (out.getVersion().onOrAfter(Version.V_2_19_0)) { + out.writeTimeValue(waitForCompletionTimeout); + } } @Override @@ -95,4 +116,11 @@ public Template getUpdatedTemplate() { return this.updatedTemplate; } + /** + * Gets the waitForCompletion timeout value + * @return the timeout value + */ + public TimeValue getWaitForCompletionTimeout() { + return this.waitForCompletionTimeout; + } } diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java index 8e501228..9c9681dc 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java @@ -19,6 +19,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -34,6 +35,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.util.WorkflowTimeoutUtility; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; @@ -48,6 +50,8 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; @@ -243,9 +247,23 @@ private void executeReprovisionRequest( Template updatedTemplateWithProvisionedTime = Template.builder(updatedTemplate) .lastProvisionedTime(Instant.now()) .build(); - executeWorkflowAsync(workflowId, updatedTemplateWithProvisionedTime, reprovisionProcessSequence, listener); - - listener.onResponse(new WorkflowResponse(workflowId)); + if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { + executeWorkflowAsync(workflowId, updatedTemplateWithProvisionedTime, reprovisionProcessSequence, listener); + } else { + executeWorkflowSync( + workflowId, + updatedTemplate, + reprovisionProcessSequence, + listener, + request.getWaitForCompletionTimeout().getMillis() + ); + } + + if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) { + listener.onResponse(new WorkflowResponse(workflowId)); + } else { + logger.info("Waiting for workflow completion"); + } }, exception -> { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to update workflow state: {}", workflowId) @@ -284,13 +302,42 @@ private void executeWorkflowAsync( try { threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { updateTemplate(template, workflowId); - executeWorkflow(template, workflowSequence, workflowId); + executeWorkflow(template, workflowSequence, workflowId, listener, false); }); } catch (Exception exception) { listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(exception))); } } + private void executeWorkflowSync( + String workflowId, + Template template, + List workflowSequence, + ActionListener listener, + long timeout + ) { + AtomicBoolean isResponseSent = new AtomicBoolean(false); + CompletableFuture.runAsync(() -> { + try { + updateTemplate(template, workflowId); + executeWorkflow(template, workflowSequence, workflowId, new ActionListener<>() { + @Override + public void onResponse(WorkflowResponse workflowResponse) { + WorkflowTimeoutUtility.handleResponse(workflowId, workflowResponse, isResponseSent, listener); + } + + @Override + public void onFailure(Exception e) { + WorkflowTimeoutUtility.handleFailure(workflowId, e, isResponseSent, listener); + } + }, true); + } catch (Exception ex) { + WorkflowTimeoutUtility.handleFailure(workflowId, ex, isResponseSent, listener); + } + }, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL)); + WorkflowTimeoutUtility.scheduleTimeoutHandler(client, threadPool, workflowId, listener, timeout, isResponseSent); + } + /** * Replace template document * @param template The template to store after reprovisioning completes successfully @@ -310,7 +357,13 @@ private void updateTemplate(Template template, String workflowId) { * @param workflowSequence The topologically sorted workflow to execute * @param workflowId The workflowId associated with the workflow that is executing */ - private void executeWorkflow(Template template, List workflowSequence, String workflowId) { + private void executeWorkflow( + Template template, + List workflowSequence, + String workflowId, + ActionListener listener, + boolean isSyncExecution + ) { String currentStepId = ""; try { Map> workflowFutureMap = new LinkedHashMap<>(); @@ -349,7 +402,23 @@ private void executeWorkflow(Template template, List workflowSequen ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); - + if (isSyncExecution) { + client.execute( + GetWorkflowStateAction.INSTANCE, + new GetWorkflowStateRequest(workflowId, false), + ActionListener.wrap(response -> { + listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())); + }, exception -> { + String errorMessage = "Failed to get workflow state."; + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }) + ); + } }, exception -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exception); }) ); } catch (Exception ex) { diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 97f032e3..fe5912c2 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -8,9 +8,11 @@ */ package org.opensearch.flowframework.transport; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.Nullable; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.flowframework.model.Template; @@ -21,6 +23,7 @@ import static org.opensearch.flowframework.common.CommonValue.REPROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; /** * Transport Request to create, provision, and deprovision a workflow @@ -62,13 +65,20 @@ public class WorkflowRequest extends ActionRequest { */ private Map params; + /** + * The timeout duration to wait for workflow completion. + * default set to -1, the request will respond immediately with the workflowId, + * indicating asynchronous execution. + */ + private TimeValue waitForCompletionTimeout; + /** * Instantiates a new WorkflowRequest, set validation to all, no provisioning * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), false); + this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), false, TimeValue.MINUS_ONE); } /** @@ -78,7 +88,27 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) * @param params The parameters from the REST path */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map params) { - this(workflowId, template, new String[] { "all" }, true, params, false); + this(workflowId, template, new String[] { "all" }, true, params, false, TimeValue.MINUS_ONE); + } + + /** + * Instantiates a new WorkflowRequest with a specified wait-for-completion timeout. + * This constructor allows the caller to specify a custom timeout for the workflow execution, + * which determines how long the system will wait for the workflow to complete before returning a response. + * By default, the validation is set to "all", and provisioning is set to true. + * @param workflowId The unique document ID of the workflow. Can be null for new workflows. + * @param template The use case template that defines the structure and logic of the workflow. Can be null if not provided. + * @param params A map of parameters extracted from the REST request path, used to customize the workflow execution. + * @param waitForCompletionTimeout The maximum duration to wait for the workflow execution to complete. + * If the workflow does not complete within this timeout, the request will return a timeout response. + */ + public WorkflowRequest( + @Nullable String workflowId, + @Nullable Template template, + Map params, + TimeValue waitForCompletionTimeout + ) { + this(workflowId, template, new String[] { "all" }, true, params, false, waitForCompletionTimeout); } /** @@ -97,17 +127,41 @@ public WorkflowRequest( boolean provisionOrUpdate, Map params, boolean reprovision + ) { + this(workflowId, template, validation, provisionOrUpdate, params, reprovision, TimeValue.MINUS_ONE); + } + + /** + * Instantiates a new WorkflowRequest + * @param workflowId the documentId of the workflow + * @param template the use case template which describes the workflow + * @param validation flag to indicate if validation is necessary + * @param provisionOrUpdate provision or updateFields flag. Only one may be true, the presence of update_fields key in map indicates if updating fields, otherwise true means it's provisioning. + * @param params map of REST path params. If provisionOrUpdate is false, must be an empty map. If update_fields key is present, must be only key. + * @param reprovision flag to indicate if request is to reprovision + * @param waitForCompletionTimeout the timeout duration to wait for workflow completion + */ + public WorkflowRequest( + @Nullable String workflowId, + @Nullable Template template, + String[] validation, + boolean provisionOrUpdate, + Map params, + boolean reprovision, + TimeValue waitForCompletionTimeout ) { this.workflowId = workflowId; this.template = template; this.validation = validation; this.provision = provisionOrUpdate && !params.containsKey(UPDATE_WORKFLOW_FIELDS); this.updateFields = !provision && Boolean.parseBoolean(params.get(UPDATE_WORKFLOW_FIELDS)); - if (!this.provision && params.keySet().stream().anyMatch(k -> !UPDATE_WORKFLOW_FIELDS.equals(k))) { + if (!this.provision + && params.keySet().stream().anyMatch(k -> !UPDATE_WORKFLOW_FIELDS.equals(k) && !WAIT_FOR_COMPLETION_TIMEOUT.equals(k))) { throw new IllegalArgumentException("Params may only be included when provisioning."); } this.params = this.updateFields ? Collections.emptyMap() : params; this.reprovision = reprovision; + this.waitForCompletionTimeout = waitForCompletionTimeout; } /** @@ -133,6 +187,10 @@ public WorkflowRequest(StreamInput in) throws IOException { this.params = Collections.emptyMap(); } this.reprovision = !provision && Boolean.parseBoolean(params.get(REPROVISION_WORKFLOW)); + if (in.getVersion().onOrAfter(Version.V_2_19_0)) { + this.waitForCompletionTimeout = in.readOptionalTimeValue(); + } + } /** @@ -193,6 +251,15 @@ public boolean isReprovision() { return this.reprovision; } + /** + * Gets the timeout duration (in milliseconds) to wait for workflow completion. + * @return the timeout duration, or null if the request should return immediately + */ + @Nullable + public TimeValue getWaitForCompletionTimeout() { + return this.waitForCompletionTimeout; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -207,6 +274,9 @@ public void writeTo(StreamOutput out) throws IOException { } else if (reprovision) { out.writeMap(Map.of(REPROVISION_WORKFLOW, "true"), StreamOutput::writeString, StreamOutput::writeString); } + if (out.getVersion().onOrAfter(Version.V_2_19_0)) { + out.writeOptionalTimeValue(waitForCompletionTimeout); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java index 20a7700a..20d7f475 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java @@ -8,11 +8,14 @@ */ package org.opensearch.flowframework.transport; +import org.opensearch.Version; +import org.opensearch.common.Nullable; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.WorkflowState; import java.io.IOException; @@ -27,6 +30,8 @@ public class WorkflowResponse extends ActionResponse implements ToXContentObject * The documentId of the workflow entry within the Global Context index */ private String workflowId; + /** The workflow state */ + private WorkflowState workflowState; /** * Instantiates a new WorkflowResponse from params @@ -44,6 +49,10 @@ public WorkflowResponse(String workflowId) { public WorkflowResponse(StreamInput in) throws IOException { super(in); this.workflowId = in.readString(); + if (in.getVersion().onOrAfter(Version.V_2_19_0)) { + this.workflowState = in.readOptionalWriteable(WorkflowState::new); + } + } /** @@ -54,14 +63,49 @@ public String getWorkflowId() { return this.workflowId; } + /** + * Gets the workflowState of this repsonse + * @return the workflowState + */ + @Nullable + public WorkflowState getWorkflowState() { + return this.workflowState; + } + + /** + * Constructs a new WorkflowResponse object with the specified workflowId and workflowState. + * The WorkflowResponse is typically returned as part of a `wait_for_completion` request, + * indicating the final state of a workflow after execution. + * @param workflowId The unique identifier for the workflow. + * @param workflowState The current state of the workflow, including status, errors (if any), + * and resources created as part of the workflow execution. + */ + public WorkflowResponse(String workflowId, WorkflowState workflowState) { + this.workflowId = workflowId; + this.workflowState = WorkflowState.builder() + .workflowId(workflowId) + .error(workflowState.getError()) + .state(workflowState.getState()) + .resourcesCreated(workflowState.resourcesCreated()) + .build(); + + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(workflowId); + if (out.getVersion().onOrAfter(Version.V_2_19_0)) { + out.writeOptionalWriteable(workflowState); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject().field(WORKFLOW_ID, this.workflowId).endObject(); + if (workflowState != null) { + return workflowState.toXContent(builder, params); + } else { + return builder.startObject().field(WORKFLOW_ID, this.workflowId).endObject(); + } } } diff --git a/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java new file mode 100644 index 00000000..cbed72b3 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/util/WorkflowTimeoutUtility.java @@ -0,0 +1,202 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.GetWorkflowStateAction; +import org.opensearch.flowframework.transport.GetWorkflowStateRequest; +import org.opensearch.flowframework.transport.WorkflowResponse; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; + +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; + +/** + * Utility class for managing timeout tasks in workflow execution. + * This class provides methods to schedule timeout handlers, wrap listeners with timeout cancellation logic, + * and fetch workflow states after timeouts. + */ +public class WorkflowTimeoutUtility { + + private static final Logger logger = LogManager.getLogger(WorkflowTimeoutUtility.class); + + /** + * Schedules a timeout task for a workflow execution. + * + * @param client The OpenSearch client used to interact with the cluster. + * @param threadPool The thread pool to schedule the timeout task. + * @param workflowId The unique identifier of the workflow being executed. + * @param listener The listener to notify when the task completes or times out. + * @param timeout The timeout duration in milliseconds. + * @param isResponseSent An atomic boolean to ensure the response is sent only once. + * @return A wrapped ActionListener with timeout cancellation logic. + */ + public static ActionListener scheduleTimeoutHandler( + Client client, + ThreadPool threadPool, + final String workflowId, + ActionListener listener, + long timeout, + AtomicBoolean isResponseSent + ) { + // Ensure timeout is within the valid range (non-negative) + long adjustedTimeout = Math.max(timeout, TimeValue.timeValueMillis(0).millis()); + Scheduler.ScheduledCancellable scheduledCancellable = threadPool.schedule( + new WorkflowTimeoutListener(client, workflowId, listener, isResponseSent), + TimeValue.timeValueMillis(adjustedTimeout), + PROVISION_WORKFLOW_THREAD_POOL + ); + + return wrapWithTimeoutCancellationListener(listener, scheduledCancellable, isResponseSent); + } + + /** + * A listener that handles timeout for a workflow execution. + */ + private static class WorkflowTimeoutListener implements Runnable { + private final Client client; + private final String workflowId; + private final ActionListener listener; + private final AtomicBoolean isResponseSent; + + WorkflowTimeoutListener(Client client, String workflowId, ActionListener listener, AtomicBoolean isResponseSent) { + this.client = client; + this.workflowId = workflowId; + this.listener = listener; + this.isResponseSent = isResponseSent; + } + + @Override + public void run() { + // This AtomicBoolean ensures that the timeout logic is executed only once, preventing duplicate responses. + if (isResponseSent.compareAndSet(false, true)) { + logger.warn("Workflow execution timed out for workflowId: {}", workflowId); + fetchWorkflowStateAfterTimeout(client, workflowId, listener); + } + } + } + + /** + * Wraps a listener with a timeout cancellation listener to cancel the timeout task when the workflow completes. + * + * @param listener The original listener to wrap. + * @param scheduledCancellable The cancellable timeout task. + * @param isResponseSent An atomic boolean to ensure the response is sent only once. + * @param The type of the response expected by the listener. + * @return A wrapped ActionListener with timeout cancellation logic. + */ + public static ActionListener wrapWithTimeoutCancellationListener( + ActionListener listener, + Scheduler.ScheduledCancellable scheduledCancellable, + AtomicBoolean isResponseSent + ) { + return new ActionListener<>() { + @Override + public void onResponse(Response response) { + // Cancel the timeout task if the response is successfully sent. + if (isResponseSent.compareAndSet(false, true)) { + scheduledCancellable.cancel(); + } + listener.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + // Cancel the timeout task if an error occurs and the failure is reported. + if (isResponseSent.compareAndSet(false, true)) { + scheduledCancellable.cancel(); + } + listener.onFailure(e); + } + }; + } + + /** + * Handles the successful completion of a workflow. + * + * @param workflowId The unique identifier of the workflow. + * @param workflowResponse The response from the workflow execution. + * @param isResponseSent An atomic boolean to ensure the response is sent only once. + * @param listener The listener to notify of the workflow completion. + */ + public static void handleResponse( + String workflowId, + WorkflowResponse workflowResponse, + AtomicBoolean isResponseSent, + ActionListener listener + ) { + // Check if the response has already been sent, and send it only if it hasn't been sent yet. + if (isResponseSent.compareAndSet(false, true)) { + listener.onResponse(new WorkflowResponse(workflowResponse.getWorkflowId(), workflowResponse.getWorkflowState())); + } else { + logger.info("Ignoring onResponse for workflowId: {} as timeout already occurred", workflowId); + } + } + + /** + * Handles the failure of a workflow execution. + * + * @param workflowId The unique identifier of the workflow. + * @param e The exception that occurred during workflow execution. + * @param isResponseSent An atomic boolean to ensure the response is sent only once. + * @param listener The listener to notify of the workflow failure. + */ + public static void handleFailure( + String workflowId, + Exception e, + AtomicBoolean isResponseSent, + ActionListener listener + ) { + // Check if the failure has already been reported, and report it only if it hasn't been reported yet. + if (isResponseSent.compareAndSet(false, true)) { + FlowFrameworkException exception = new FlowFrameworkException( + "Failed to execute workflow " + workflowId, + ExceptionsHelper.status(e) + ); + listener.onFailure(exception); + } else { + logger.info("Ignoring onFailure for workflowId: {} as timeout already occurred", workflowId); + } + } + + /** + * Fetches the workflow state after a timeout has occurred. + * This method sends a request to retrieve the current state of the workflow + * and notifies the listener with the updated state or an error if the request fails. + * + * @param client The OpenSearch client used to fetch the workflow state. + * @param workflowId The unique identifier of the workflow. + * @param listener The listener to notify with the updated state or failure. + */ + public static void fetchWorkflowStateAfterTimeout( + final Client client, + final String workflowId, + final ActionListener listener + ) { + logger.info("Fetching workflow state after timeout"); + client.execute( + GetWorkflowStateAction.INSTANCE, + new GetWorkflowStateRequest(workflowId, false), + ActionListener.wrap( + response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())), + exception -> listener.onFailure( + new FlowFrameworkException("Failed to get workflow state after timeout", ExceptionsHelper.status(exception)) + ) + ) + ); + } +} diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index f6b1a5fc..747de435 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -39,6 +39,7 @@ import static org.opensearch.flowframework.common.CommonValue.REPROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -128,6 +129,26 @@ public void testCreateWorkflowRequestWithParamsAndProvision() throws Exception { assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); } + public void testRestCreateWorkflowWithWaitForCompletionTimeout() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withParams(Map.ofEntries(Map.entry(PROVISION_WORKFLOW, "true"), Map.entry("wait_for_completion_timeout", "5s"))) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(new WorkflowResponse("workflow_1")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + + assertEquals(RestStatus.CREATED, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_1")); + } + public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) @@ -142,6 +163,23 @@ public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception ); } + public void testCreateWorkflowRequestWithWaitForTimeCompletionTimeoutButNoProvision() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.of(WAIT_FOR_COMPLETION_TIMEOUT, "1s")) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue( + channel.capturedResponse() + .content() + .utf8ToString() + .contains("are not allowed unless the 'provision' or 'reprovision' parameter is set to true.") + ); + } + public void testCreateWorkflowRequestWithUpdateAndProvision() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) diff --git a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java index fd5cd478..625e48e3 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java @@ -144,4 +144,26 @@ public void testFeatureFlagNotEnabled() throws Exception { assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); } + + public void testProvisionWorkflowWithValidWaitForCompletionTimeout() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .withParams(Map.of("workflow_id", "abc", "wait_for_completion_timeout", "5s")) + .withContent(new BytesArray("{\"foo\": \"bar\"}"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("workflow_1")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + + assertEquals(RestStatus.OK, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_1")); + } + } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index ba76bc83..a3868875 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -38,6 +38,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.plugins.PluginsService; import org.opensearch.search.SearchHit; @@ -48,6 +49,7 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -62,6 +64,7 @@ import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.CREATE_CONNECTOR; @@ -252,7 +255,7 @@ public void testMaxWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false,null); doAnswer(invocation -> { ActionListener searchListener = invocation.getArgument(1); @@ -289,7 +292,15 @@ public void onFailure(Exception e) { public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -320,7 +331,15 @@ public void testFailedToCreateNewWorkflow() { public void testCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -384,7 +403,15 @@ public void testCreateWithUserAndFilterOn() { ); ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -448,7 +475,15 @@ public void testFailedToCreateNewWorkflowWithNullUser() { ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); createWorkflowTransportAction1.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -483,7 +518,15 @@ public void testFailedToCreateNewWorkflowWithNoBackendRoleUser() { ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); createWorkflowTransportAction1.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -497,7 +540,15 @@ public void testFailedToCreateNewWorkflowWithNoBackendRoleUser() { public void testUpdateWorkflowWithReprovision() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest("1", template, new String[] { "off" }, false, Collections.emptyMap(), true); + WorkflowRequest workflowRequest = new WorkflowRequest( + "1", + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + true, + null + ); doAnswer(invocation -> { ActionListener getListener = invocation.getArgument(1); @@ -541,7 +592,15 @@ public void testUpdateWorkflowWithReprovision() throws IOException { public void testFailedToUpdateWorkflowWithReprovision() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest("1", template, new String[] { "off" }, false, Collections.emptyMap(), true); + WorkflowRequest workflowRequest = new WorkflowRequest( + "1", + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + true, + null + ); doAnswer(invocation -> { ActionListener getListener = invocation.getArgument(1); @@ -841,6 +900,68 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc new String[] { "all" }, true, Collections.emptyMap(), + false, + null + ); + + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + + // Bypass putTemplateToGlobalContext and force onResponse + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + + // Bypass putInitialStateToWorkflowState and force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + WorkflowResponse response = mock(WorkflowResponse.class); + when(response.getWorkflowId()).thenReturn("1"); + responseListener.onResponse(response); + return null; + }).when(client).execute(eq(ProvisionWorkflowAction.INSTANCE), any(WorkflowRequest.class), any(ActionListener.class)); + + ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + verify(listener, times(1)).onResponse(workflowResponseCaptor.capture()); + assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); + } + + public void testCreateWorkflow_withValidation_withWaitForCompletion_withProvision_Success() throws Exception { + + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + doNothing().when(workflowProcessSorter).validate(any(), any()); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + new String[] { "all" }, + true, + Map.of(WAIT_FOR_COMPLETION_TIMEOUT, "5s"), false ); @@ -876,6 +997,19 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc ActionListener responseListener = invocation.getArgument(2); WorkflowResponse response = mock(WorkflowResponse.class); when(response.getWorkflowId()).thenReturn("1"); + when(response.getWorkflowState()).thenReturn( + new WorkflowState( + "1", + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ) + ); responseListener.onResponse(response); return null; }).when(client).execute(eq(ProvisionWorkflowAction.INSTANCE), any(WorkflowRequest.class), any(ActionListener.class)); @@ -886,6 +1020,82 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc verify(listener, times(1)).onResponse(workflowResponseCaptor.capture()); assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); + assertEquals("PROVISIONING", workflowResponseCaptor.getValue().getWorkflowState().getState()); + } + + public void testCreateWorkflow_withValidation_withWaitForCompletionTimeSetZero_withProvision_Success() throws Exception { + + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + doNothing().when(workflowProcessSorter).validate(any(), any()); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + new String[] { "all" }, + true, + Map.of(WAIT_FOR_COMPLETION_TIMEOUT, "0s"), + false + ); + + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + + // Bypass putTemplateToGlobalContext and force onResponse + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + + // Bypass putInitialStateToWorkflowState and force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + WorkflowResponse response = mock(WorkflowResponse.class); + when(response.getWorkflowId()).thenReturn("1"); + when(response.getWorkflowState()).thenReturn( + new WorkflowState( + "1", + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ) + ); + responseListener.onResponse(response); + return null; + }).when(client).execute(eq(ProvisionWorkflowAction.INSTANCE), any(WorkflowRequest.class), any(ActionListener.class)); + + ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + verify(listener, times(1)).onResponse(workflowResponseCaptor.capture()); + assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); + assertEquals("PROVISIONING", workflowResponseCaptor.getValue().getWorkflowState().getState()); } public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() throws Exception { @@ -901,7 +1111,8 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() new String[] { "all" }, true, Collections.emptyMap(), - false + false, + null ); // Bypass checkMaxWorkflows and force onResponse diff --git a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequestTests.java b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequestTests.java index 2448d937..f122d271 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequestTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowRequestTests.java @@ -10,6 +10,7 @@ import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.BytesStreamInput; import org.opensearch.flowframework.TestHelpers; @@ -72,7 +73,7 @@ public void setUp() throws Exception { } public void testReprovisionWorkflowRequest() throws IOException { - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest("123", originalTemplate, updatedTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest("123", originalTemplate, updatedTemplate, TimeValue.MINUS_ONE); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); diff --git a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java index 6e1e65d3..e13b1f0d 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java @@ -14,6 +14,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.common.FlowFrameworkSettings; @@ -152,7 +153,7 @@ public void testReprovisionWorkflow() throws Exception { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate, TimeValue.MINUS_ONE); reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); @@ -189,7 +190,7 @@ public void testReprovisionProvisioningWorkflow() throws Exception { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate, TimeValue.MINUS_ONE); reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -229,7 +230,7 @@ public void testReprovisionNotStartedWorkflow() throws Exception { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate, TimeValue.MINUS_ONE); reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -280,7 +281,7 @@ public void testFailedStateUpdate() throws Exception { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate, TimeValue.MINUS_ONE); reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -312,7 +313,7 @@ public void testFailedWorkflowStateRetrieval() throws Exception { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate); + ReprovisionWorkflowRequest request = new ReprovisionWorkflowRequest(workflowId, mockTemplate, mockTemplate, TimeValue.MINUS_ONE); reprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java index e92255e0..50c60a19 100644 --- a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java @@ -21,9 +21,11 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; +import java.time.Instant; import java.util.Collections; import java.util.List; import java.util.Map; @@ -156,7 +158,7 @@ public void testWorkflowRequestWithParams() throws IOException { public void testWorkflowRequestWithParamsNoProvision() throws IOException { IllegalArgumentException ex = assertThrows( IllegalArgumentException.class, - () -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar"), false) + () -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar"), false, null) ); assertEquals("Params may only be included when provisioning.", ex.getMessage()); } @@ -168,7 +170,8 @@ public void testWorkflowRequestWithOnlyUpdateParamNoProvision() throws IOExcepti new String[] { "all" }, true, Map.of(UPDATE_WORKFLOW_FIELDS, "true"), - false + false, + null ); assertNotNull(workflowRequest.getWorkflowId()); assertEquals(template, workflowRequest.getTemplate()); @@ -208,4 +211,37 @@ public void testWorkflowResponse() throws IOException { assertEquals("{\"workflow_id\":\"123\"}", builder.toString()); } + public void testWorkflowResponseWithWaitForCompletionTimeOut() throws IOException { + WorkflowState workFlowState = new WorkflowState( + "123", + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ); + + WorkflowResponse response = new WorkflowResponse("123", workFlowState); + assertEquals("123", response.getWorkflowId()); + assertEquals("PROVISIONING", response.getWorkflowState().getState()); + + BytesStreamOutput out = new BytesStreamOutput(); + response.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + WorkflowResponse streamInputResponse = new WorkflowResponse(in); + + assertEquals(response.getWorkflowId(), streamInputResponse.getWorkflowId()); + assertEquals(response.getWorkflowState().getState(), streamInputResponse.getWorkflowState().getState()); + + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + + assertNotNull(builder); + assertTrue(builder.toString().contains("\"workflow_id\":\"123\"")); + assertTrue(builder.toString().contains("\"state\":\"PROVISIONING\"")); + } + } diff --git a/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java b/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java new file mode 100644 index 00000000..d7dcaeda --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/util/WorkflowTimeoutUtilityTests.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.transport.GetWorkflowStateAction; +import org.opensearch.flowframework.transport.GetWorkflowStateRequest; +import org.opensearch.flowframework.transport.WorkflowResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; + +import java.time.Instant; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class WorkflowTimeoutUtilityTests extends OpenSearchTestCase { + + private Client mockClient; + private ThreadPool mockThreadPool; + private Scheduler.ScheduledCancellable mockScheduledCancellable; + private AtomicBoolean isResponseSent; + private ActionListener mockListener; + + @Override + public void setUp() throws Exception { + super.setUp(); + mockClient = mock(Client.class); + mockThreadPool = mock(ThreadPool.class); + mockScheduledCancellable = mock(Scheduler.ScheduledCancellable.class); + isResponseSent = new AtomicBoolean(false); + mockListener = mock(ActionListener.class); + + when(mockThreadPool.schedule(any(Runnable.class), any(TimeValue.class), anyString())).thenReturn(mockScheduledCancellable); + } + + public void testScheduleTimeoutHandler() { + String workflowId = "testWorkflowId"; + long timeout = 1000L; + + ActionListener returnedListener = WorkflowTimeoutUtility.scheduleTimeoutHandler( + mockClient, + mockThreadPool, + workflowId, + mockListener, + timeout, + isResponseSent + ); + + assertNotNull(returnedListener); + verify(mockThreadPool, times(1)).schedule( + any(Runnable.class), + eq(TimeValue.timeValueMillis(timeout)), + eq(PROVISION_WORKFLOW_THREAD_POOL) + ); + } + + public void testWrapWithTimeoutCancellationListenerOnResponse() { + WorkflowResponse response = new WorkflowResponse( + "testWorkflowId", + new WorkflowState( + "1", + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ) + ); + Scheduler.ScheduledCancellable scheduledCancellable = mock(Scheduler.ScheduledCancellable.class); + + ActionListener wrappedListener = WorkflowTimeoutUtility.wrapWithTimeoutCancellationListener( + mockListener, + scheduledCancellable, + isResponseSent + ); + + wrappedListener.onResponse(response); + + assertTrue(isResponseSent.get()); + verify(scheduledCancellable, times(1)).cancel(); + verify(mockListener, times(1)).onResponse(response); + } + + public void testWrapWithTimeoutCancellationListenerOnFailure() { + Exception exception = new Exception("Test exception"); + Scheduler.ScheduledCancellable scheduledCancellable = mock(Scheduler.ScheduledCancellable.class); + + ActionListener wrappedListener = WorkflowTimeoutUtility.wrapWithTimeoutCancellationListener( + mockListener, + scheduledCancellable, + isResponseSent + ); + + wrappedListener.onFailure(exception); + + assertTrue(isResponseSent.get()); + verify(scheduledCancellable, times(1)).cancel(); + verify(mockListener, times(1)).onFailure(exception); + } + + public void testFetchWorkflowStateAfterTimeout() { + String workflowId = "testWorkflowId"; + ActionListener mockListener = mock(ActionListener.class); + + WorkflowTimeoutUtility.fetchWorkflowStateAfterTimeout(mockClient, workflowId, mockListener); + + verify(mockClient, times(1)).execute( + eq(GetWorkflowStateAction.INSTANCE), + any(GetWorkflowStateRequest.class), + any(ActionListener.class) + ); + } +}