Skip to content

Commit

Permalink
Passing tenantId to MLClient for Connector, Model and ModelGroup
Browse files Browse the repository at this point in the history
Signed-off-by: Siddhartha Bingi <sidbingi@amazon.com>
  • Loading branch information
Siddhartha Bingi committed Jan 23, 2025
1 parent 1c16b1b commit 3f7a230
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ public void onFailure(Exception ex) {
.parameters(parameters)
.credential(credentials)
.actions(actions)
.tenantId(tenantId)
.build();

mlClient.createConnector(mlInput, actionListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public PlainActionFuture<WorkflowData> execute(
);
String connectorId = (String) inputs.get(CONNECTOR_ID);

mlClient.deleteConnector(connectorId, new ActionListener<>() {
mlClient.deleteConnector(connectorId, tenantId, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteConnectorFuture.onResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public PlainActionFuture<WorkflowData> execute(

String modelId = inputs.get(MODEL_ID).toString();

mlClient.deleteModel(modelId, new ActionListener<>() {
mlClient.deleteModel(modelId, tenantId, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteModelFuture.onResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ public void onFailure(Exception ex) {

MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder();
builder.name(modelGroupName);
builder.tenantId(tenantId);
if (description != null) {
builder.description(description);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ public PlainActionFuture<WorkflowData> execute(
MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
.functionName(FunctionName.REMOTE)
.modelName(modelName)
.connectorId(connectorId);
.connectorId(connectorId)
.tenantId(tenantId);

if (modelGroupId != null) {
builder.modelGroupId(modelGroupId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

Expand All @@ -54,12 +55,12 @@ public void testDeleteConnector() throws IOException, ExecutionException, Interr

doAnswer(invocation -> {
String connectorIdArg = invocation.getArgument(0);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, connectorIdArg, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).deleteConnector(anyString(), anyActionListener());
}).when(machineLearningNodeClient).deleteConnector(anyString(), nullable(String.class), anyActionListener());

PlainActionFuture<WorkflowData> future = deleteConnectorStep.execute(
inputData.getNodeId(),
Expand All @@ -69,7 +70,7 @@ public void testDeleteConnector() throws IOException, ExecutionException, Interr
Collections.emptyMap(),
null
);
verify(machineLearningNodeClient).deleteConnector(anyString(), anyActionListener());
verify(machineLearningNodeClient).deleteConnector(anyString(), nullable(String.class), anyActionListener());

assertTrue(future.isDone());
assertEquals(connectorId, future.get().getContent().get(CONNECTOR_ID));
Expand Down Expand Up @@ -97,10 +98,10 @@ public void testDeleteConnectorFailure() throws IOException {
DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient);

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new FlowFrameworkException("Failed to delete connector", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).deleteConnector(anyString(), anyActionListener());
}).when(machineLearningNodeClient).deleteConnector(anyString(), nullable(String.class), anyActionListener());

PlainActionFuture<WorkflowData> future = deleteConnectorStep.execute(
inputData.getNodeId(),
Expand All @@ -111,7 +112,7 @@ public void testDeleteConnectorFailure() throws IOException {
null
);

verify(machineLearningNodeClient).deleteConnector(anyString(), anyActionListener());
verify(machineLearningNodeClient).deleteConnector(anyString(), nullable(String.class), anyActionListener());

assertTrue(future.isDone());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

Expand All @@ -54,12 +55,12 @@ public void testDeleteModel() throws IOException, ExecutionException, Interrupte

doAnswer(invocation -> {
String modelIdArg = invocation.getArgument(0);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, modelIdArg, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).deleteModel(any(String.class), any());
}).when(machineLearningNodeClient).deleteModel(any(String.class), nullable(String.class), any());

PlainActionFuture<WorkflowData> future = deleteModelStep.execute(
inputData.getNodeId(),
Expand All @@ -69,7 +70,7 @@ public void testDeleteModel() throws IOException, ExecutionException, Interrupte
Collections.emptyMap(),
null
);
verify(machineLearningNodeClient).deleteModel(any(String.class), any());
verify(machineLearningNodeClient).deleteModel(any(String.class), nullable(String.class), any());

assertTrue(future.isDone());
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
Expand All @@ -81,10 +82,10 @@ public void testDeleteModelNotFound() throws IOException, ExecutionException, In
DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient);

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new OpenSearchStatusException("No model found with that id", RestStatus.NOT_FOUND));
return null;
}).when(machineLearningNodeClient).deleteModel(any(String.class), any());
}).when(machineLearningNodeClient).deleteModel(any(String.class), nullable(String.class), any());

PlainActionFuture<WorkflowData> future = deleteModelStep.execute(
inputData.getNodeId(),
Expand All @@ -94,7 +95,7 @@ public void testDeleteModelNotFound() throws IOException, ExecutionException, In
Collections.emptyMap(),
null
);
verify(machineLearningNodeClient).deleteModel(any(String.class), any());
verify(machineLearningNodeClient).deleteModel(any(String.class), nullable(String.class), any());

assertTrue(future.isDone());
assertEquals(modelId, future.get().getContent().get(MODEL_ID));
Expand Down Expand Up @@ -122,10 +123,10 @@ public void testDeleteModelFailure() throws IOException {
DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient);

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new FlowFrameworkException("Failed to delete model", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).deleteModel(any(String.class), any());
}).when(machineLearningNodeClient).deleteModel(any(String.class), nullable(String.class), any());

PlainActionFuture<WorkflowData> future = deleteModelStep.execute(
inputData.getNodeId(),
Expand All @@ -136,7 +137,7 @@ public void testDeleteModelFailure() throws IOException {
null
);

verify(machineLearningNodeClient).deleteModel(any(String.class), any());
verify(machineLearningNodeClient).deleteModel(any(String.class), nullable(String.class), any());

assertTrue(future.isDone());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
Expand Down

0 comments on commit 3f7a230

Please sign in to comment.