Skip to content

Commit

Permalink
Passing tenantId to MLClient for Model and Agent Steps (#1028)
Browse files Browse the repository at this point in the history
Passing tenantId to MLClient for DeployModel, UndeployModel and Agent Steps

Signed-off-by: Siddhartha Bingi <sidbingi@amazon.com>
Co-authored-by: Siddhartha Bingi <sidbingi@amazon.com>
(cherry picked from commit 9108b2e)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and Siddhartha Bingi committed Jan 28, 2025
1 parent 82bfbc3 commit 37f4028
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public PlainActionFuture<WorkflowData> execute(
);
String agentId = (String) inputs.get(AGENT_ID);

mlClient.deleteAgent(agentId, new ActionListener<>() {
mlClient.deleteAgent(agentId, tenantId, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteAgentFuture.onResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public PlainActionFuture<WorkflowData> execute(

String modelId = (String) inputs.get(MODEL_ID);

mlClient.deploy(modelId, new ActionListener<>() {
mlClient.deploy(modelId, tenantId, new ActionListener<>() {
@Override
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
logger.info("Model deployment state {}", mlDeployModelResponse.getStatus());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ public void onFailure(Exception ex) {
.parameters(parametersMap)
.createdTime(createdTime)
.lastUpdateTime(lastUpdateTime)
.appType(appType);
.appType(appType)
.tenantId(tenantId);

MLAgent mlAgent = builder.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public PlainActionFuture<WorkflowData> execute(

String modelId = inputs.get(MODEL_ID).toString();

mlClient.undeploy(new String[] { modelId }, null, new ActionListener<>() {
mlClient.undeploy(new String[] { modelId }, null, tenantId, new ActionListener<>() {
@Override
public void onResponse(MLUndeployModelsResponse mlUndeployModelsResponse) {
List<FailedNodeException> failures = mlUndeployModelsResponse.getResponse().failures();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

Expand All @@ -54,12 +55,12 @@ public void testDeleteAgent() throws IOException, ExecutionException, Interrupte

doAnswer(invocation -> {
String agentIdArg = invocation.getArgument(0);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, agentIdArg, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).deleteAgent(any(String.class), any());
}).when(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any());

PlainActionFuture<WorkflowData> future = deleteAgentStep.execute(
inputData.getNodeId(),
Expand All @@ -69,7 +70,7 @@ public void testDeleteAgent() throws IOException, ExecutionException, Interrupte
Collections.emptyMap(),
null
);
verify(machineLearningNodeClient).deleteAgent(any(String.class), any());
verify(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any());

assertTrue(future.isDone());
assertEquals(agentId, future.get().getContent().get(AGENT_ID));
Expand All @@ -81,10 +82,10 @@ public void testDeleteAgentNotFound() throws IOException, ExecutionException, In
DeleteAgentStep deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient);

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new OpenSearchStatusException("No agent found with that id", RestStatus.NOT_FOUND));
return null;
}).when(machineLearningNodeClient).deleteAgent(any(String.class), any());
}).when(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any());

PlainActionFuture<WorkflowData> future = deleteAgentStep.execute(
inputData.getNodeId(),
Expand All @@ -94,7 +95,7 @@ public void testDeleteAgentNotFound() throws IOException, ExecutionException, In
Collections.emptyMap(),
null
);
verify(machineLearningNodeClient).deleteAgent(any(String.class), any());
verify(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any());

assertTrue(future.isDone());
assertEquals(agentId, future.get().getContent().get(AGENT_ID));
Expand Down Expand Up @@ -122,10 +123,10 @@ public void testDeleteAgentFailure() throws IOException {
DeleteAgentStep deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient);

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new FlowFrameworkException("Failed to delete agent", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).deleteAgent(any(String.class), any());
}).when(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any());

PlainActionFuture<WorkflowData> future = deleteAgentStep.execute(
inputData.getNodeId(),
Expand All @@ -136,7 +137,7 @@ public void testDeleteAgentFailure() throws IOException {
null
);

verify(machineLearningNodeClient).deleteAgent(any(String.class), any());
verify(machineLearningNodeClient).deleteAgent(any(String.class), nullable(String.class), any());

assertTrue(future.isDone());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -118,11 +119,11 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I
ArgumentCaptor<ActionListener<MLDeployModelResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(1);
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).deploy(eq(modelId), actionListenerCaptor.capture());
}).when(machineLearningNodeClient).deploy(eq(modelId), nullable(String.class), actionListenerCaptor.capture());

// Stub getTask for success case
doAnswer(invocation -> {
Expand Down Expand Up @@ -150,7 +151,7 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I

future.actionGet();

verify(machineLearningNodeClient, times(1)).deploy(any(String.class), any());
verify(machineLearningNodeClient, times(1)).deploy(any(String.class), nullable(String.class), any());
verify(machineLearningNodeClient, times(1)).getTask(any(), any());

assertEquals(modelId, future.get().getContent().get(MODEL_ID));
Expand All @@ -162,10 +163,10 @@ public void testDeployModelFailure() {
ArgumentCaptor<ActionListener<MLDeployModelResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(1);
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new FlowFrameworkException("Failed to deploy model", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture());
}).when(machineLearningNodeClient).deploy(eq("modelId"), nullable(String.class), actionListenerCaptor.capture());

PlainActionFuture<WorkflowData> future = deployModel.execute(
inputData.getNodeId(),
Expand All @@ -176,7 +177,7 @@ public void testDeployModelFailure() {
null
);

verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture());
verify(machineLearningNodeClient).deploy(eq("modelId"), nullable(String.class), actionListenerCaptor.capture());

ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
Expand All @@ -194,11 +195,11 @@ public void testDeployModelTaskFailure() throws IOException, InterruptedExceptio
ArgumentCaptor<ActionListener<MLDeployModelResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(1);
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).deploy(eq(modelId), actionListenerCaptor.capture());
}).when(machineLearningNodeClient).deploy(eq(modelId), nullable(String.class), actionListenerCaptor.capture());

// Stub getTask for success case
doAnswer(invocation -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import static org.opensearch.flowframework.common.CommonValue.SUCCESS;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

Expand All @@ -57,7 +58,7 @@ public void testUndeployModel() throws IOException, ExecutionException, Interrup

doAnswer(invocation -> {
ClusterName clusterName = new ClusterName("clusterName");
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(2);
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(3);
MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse(
clusterName,
Collections.emptyList(),
Expand All @@ -66,7 +67,7 @@ public void testUndeployModel() throws IOException, ExecutionException, Interrup
MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any());
}).when(machineLearningNodeClient).undeploy(any(String[].class), any(), nullable(String.class), any());

PlainActionFuture<WorkflowData> future = UndeployModelStep.execute(
inputData.getNodeId(),
Expand All @@ -76,7 +77,7 @@ public void testUndeployModel() throws IOException, ExecutionException, Interrup
Collections.emptyMap(),
null
);
verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any());
verify(machineLearningNodeClient).undeploy(any(String[].class), any(), nullable(String.class), any());

assertTrue(future.isDone());
assertTrue((boolean) future.get().getContent().get(SUCCESS));
Expand Down Expand Up @@ -105,7 +106,7 @@ public void testUndeployModelFailure() throws IOException {

doAnswer(invocation -> {
ClusterName clusterName = new ClusterName("clusterName");
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(2);
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(3);
MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse(
clusterName,
Collections.emptyList(),
Expand All @@ -116,7 +117,7 @@ public void testUndeployModelFailure() throws IOException {

actionListener.onFailure(new FlowFrameworkException("Failed to undeploy model", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any());
}).when(machineLearningNodeClient).undeploy(any(String[].class), any(), nullable(String.class), any());

PlainActionFuture<WorkflowData> future = UndeployModelStep.execute(
inputData.getNodeId(),
Expand All @@ -127,7 +128,7 @@ public void testUndeployModelFailure() throws IOException {
null
);

verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any());
verify(machineLearningNodeClient).undeploy(any(String[].class), any(), nullable(String.class), any());

assertTrue(future.isDone());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
Expand Down

0 comments on commit 37f4028

Please sign in to comment.